Realcat commited on
Commit
1b369eb
·
1 Parent(s): 20ee7b7

add: rdd sparse and dense match

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -1
  2. README.md +2 -1
  3. config/config.yaml +21 -0
  4. imcui/hloc/extract_features.py +11 -0
  5. imcui/hloc/extractors/liftfeat.py +3 -8
  6. imcui/hloc/extractors/rdd.py +56 -0
  7. imcui/hloc/match_dense.py +17 -0
  8. imcui/hloc/matchers/rdd_dense.py +52 -0
  9. imcui/third_party/rdd/.gitignore +8 -0
  10. imcui/third_party/rdd/LICENSE +201 -0
  11. imcui/third_party/rdd/RDD/RDD.py +260 -0
  12. imcui/third_party/rdd/RDD/RDD_helper.py +179 -0
  13. imcui/third_party/rdd/RDD/dataset/__init__.py +0 -0
  14. imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py +2 -0
  15. imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py +313 -0
  16. imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py +75 -0
  17. imcui/third_party/rdd/RDD/dataset/megadepth/utils.py +848 -0
  18. imcui/third_party/rdd/RDD/matchers/__init__.py +3 -0
  19. imcui/third_party/rdd/RDD/matchers/dense_matcher.py +88 -0
  20. imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py +31 -0
  21. imcui/third_party/rdd/RDD/matchers/lightglue.py +667 -0
  22. imcui/third_party/rdd/RDD/models/backbone.py +147 -0
  23. imcui/third_party/rdd/RDD/models/deformable_transformer.py +270 -0
  24. imcui/third_party/rdd/RDD/models/descriptor.py +116 -0
  25. imcui/third_party/rdd/RDD/models/detector.py +141 -0
  26. imcui/third_party/rdd/RDD/models/interpolator.py +33 -0
  27. imcui/third_party/rdd/RDD/models/ops/functions/__init__.py +13 -0
  28. imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py +72 -0
  29. imcui/third_party/rdd/RDD/models/ops/make.sh +13 -0
  30. imcui/third_party/rdd/RDD/models/ops/modules/__init__.py +12 -0
  31. imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py +125 -0
  32. imcui/third_party/rdd/RDD/models/ops/setup.py +78 -0
  33. imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp +46 -0
  34. imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h +38 -0
  35. imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu +158 -0
  36. imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h +35 -0
  37. imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh +1332 -0
  38. imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h +67 -0
  39. imcui/third_party/rdd/RDD/models/ops/src/vision.cpp +21 -0
  40. imcui/third_party/rdd/RDD/models/ops/test.py +92 -0
  41. imcui/third_party/rdd/RDD/models/position_encoding.py +48 -0
  42. imcui/third_party/rdd/RDD/models/soft_detect.py +250 -0
  43. imcui/third_party/rdd/RDD/utils/__init__.py +1 -0
  44. imcui/third_party/rdd/RDD/utils/misc.py +531 -0
  45. imcui/third_party/rdd/README.md +197 -0
  46. imcui/third_party/rdd/benchmarks/air_ground.py +247 -0
  47. imcui/third_party/rdd/benchmarks/mega_1500.py +255 -0
  48. imcui/third_party/rdd/benchmarks/mega_view.py +250 -0
  49. imcui/third_party/rdd/benchmarks/utils.py +112 -0
  50. imcui/third_party/rdd/configs/default.yaml +19 -0
.gitignore CHANGED
@@ -7,7 +7,7 @@ cmake-build-debug/
7
  *.pyc
8
  flagged
9
  .ipynb_checkpoints
10
- __pycache__
11
  Untitled*
12
  experiments
13
  third_party/REKD
 
7
  *.pyc
8
  flagged
9
  .ipynb_checkpoints
10
+ **__pycache__**
11
  Untitled*
12
  experiments
13
  third_party/REKD
README.md CHANGED
@@ -44,8 +44,9 @@ The tool currently supports various popular image matching algorithms, namely:
44
 
45
  | Algorithm | Supported | Conference/Journal | Year | GitHub Link |
46
  |------------------|-----------|--------------------|------|-------------|
47
- | DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
48
  | LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
 
 
49
  | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
50
  | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) |
51
  | EfficientLoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
 
44
 
45
  | Algorithm | Supported | Conference/Journal | Year | GitHub Link |
46
  |------------------|-----------|--------------------|------|-------------|
 
47
  | LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
48
+ | RDD | ✅ | CVPR | 2025 | [Link](https://github.com/xtcpete/rdd) |
49
+ | DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
50
  | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
51
  | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) |
52
  | EfficientLoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
config/config.yaml CHANGED
@@ -267,6 +267,27 @@ matcher_zoo:
267
  paper: https://arxiv.org/abs/2505.0342
268
  project: null
269
  display: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  dedode:
271
  matcher: Dual-Softmax
272
  feature: dedode
 
267
  paper: https://arxiv.org/abs/2505.0342
268
  project: null
269
  display: true
270
+ rdd(sparse):
271
+ matcher: NN-mutual
272
+ feature: rdd
273
+ dense: false
274
+ info:
275
+ name: RDD(sparse) #dispaly name
276
+ source: "CVPR 2025"
277
+ github: hhttps://github.com/xtcpete/rdd
278
+ paper: https://arxiv.org/abs/2505.08013
279
+ project: https://xtcpete.github.io/rdd
280
+ display: true
281
+ rdd(dense):
282
+ matcher: rdd_dense
283
+ dense: true
284
+ info:
285
+ name: RDD(dense) #dispaly name
286
+ source: "CVPR 2025"
287
+ github: hhttps://github.com/xtcpete/rdd
288
+ paper: https://arxiv.org/abs/2505.08013
289
+ project: https://xtcpete.github.io/rdd
290
+ display: true
291
  dedode:
292
  matcher: Dual-Softmax
293
  feature: dedode
imcui/hloc/extract_features.py CHANGED
@@ -225,6 +225,17 @@ confs = {
225
  "resize_max": 1600,
226
  },
227
  },
 
 
 
 
 
 
 
 
 
 
 
228
  "aliked-n16-rot": {
229
  "output": "feats-aliked-n16-rot",
230
  "model": {
 
225
  "resize_max": 1600,
226
  },
227
  },
228
+ "rdd": {
229
+ "output": "feats-rdd-n5000-r1600",
230
+ "model": {
231
+ "name": "rdd",
232
+ "max_keypoints": 5000,
233
+ },
234
+ "preprocessing": {
235
+ "grayscale": False,
236
+ "resize_max": 1600,
237
+ },
238
+ },
239
  "aliked-n16-rot": {
240
  "output": "feats-aliked-n16-rot",
241
  "model": {
imcui/hloc/extractors/liftfeat.py CHANGED
@@ -1,13 +1,10 @@
1
- import logging
2
  import sys
3
  from pathlib import Path
4
- import torch
5
- import random
6
  from ..utils.base_model import BaseModel
7
  from .. import logger, MODEL_REPO_ID
8
 
9
- fire_path = Path(__file__).parent / "../../third_party/LiftFeat"
10
- sys.path.append(str(fire_path))
11
 
12
  from models.liftfeat_wrapper import LiftFeat
13
 
@@ -25,9 +22,7 @@ class Liftfeat(BaseModel):
25
  logger.info("Loading LiftFeat model...")
26
  model_path = self._download_model(
27
  repo_id=MODEL_REPO_ID,
28
- filename="{}/{}".format(
29
- Path(__file__).stem, self.conf["model_name"]
30
- ),
31
  )
32
  self.net = LiftFeat(
33
  weight=model_path,
 
 
1
  import sys
2
  from pathlib import Path
 
 
3
  from ..utils.base_model import BaseModel
4
  from .. import logger, MODEL_REPO_ID
5
 
6
+ liftfeat_path = Path(__file__).parent / "../../third_party/LiftFeat"
7
+ sys.path.append(str(liftfeat_path))
8
 
9
  from models.liftfeat_wrapper import LiftFeat
10
 
 
22
  logger.info("Loading LiftFeat model...")
23
  model_path = self._download_model(
24
  repo_id=MODEL_REPO_ID,
25
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
 
 
26
  )
27
  self.net = LiftFeat(
28
  weight=model_path,
imcui/hloc/extractors/rdd.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import yaml
3
+ from pathlib import Path
4
+ from ..utils.base_model import BaseModel
5
+ from .. import logger, MODEL_REPO_ID, DEVICE
6
+
7
+ rdd_path = Path(__file__).parent / "../../third_party/rdd"
8
+ sys.path.append(str(rdd_path))
9
+
10
+ from RDD.RDD import build as build_rdd
11
+
12
+ class Rdd(BaseModel):
13
+ default_conf = {
14
+ "keypoint_threshold": 0.1,
15
+ "max_keypoints": 4096,
16
+ "model_name": "RDD-v2.pth",
17
+ }
18
+
19
+ required_inputs = ["image"]
20
+
21
+ def _init(self, conf):
22
+ logger.info("Loading RDD model...")
23
+ model_path = self._download_model(
24
+ repo_id=MODEL_REPO_ID,
25
+ filename="{}/{}".format(
26
+ Path(__file__).stem, self.conf["model_name"]
27
+ ),
28
+ )
29
+ config_path = rdd_path / "configs/default.yaml"
30
+ with open(config_path, "r") as file:
31
+ config = yaml.safe_load(file)
32
+ config["top_k"] = conf["max_keypoints"]
33
+ config["detection_threshold"] = conf["keypoint_threshold"]
34
+ config["device"] = DEVICE
35
+ self.net = build_rdd(config=config, weights=model_path)
36
+ self.net.eval()
37
+ logger.info("Loading RDD model done!")
38
+
39
+ def _forward(self, data):
40
+ image = data["image"]
41
+ pred = self.net.extract(image)[0]
42
+ keypoints = pred["keypoints"]
43
+ descriptors = pred["descriptors"]
44
+ scores = pred["scores"]
45
+ if self.conf["max_keypoints"] < len(keypoints):
46
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
47
+ keypoints = keypoints[idxs, :2]
48
+ descriptors = descriptors[idxs]
49
+ scores = scores[idxs]
50
+
51
+ pred = {
52
+ "keypoints": keypoints[None],
53
+ "descriptors": descriptors[None].permute(0, 2, 1),
54
+ "scores": scores[None],
55
+ }
56
+ return pred
imcui/hloc/match_dense.py CHANGED
@@ -337,6 +337,23 @@ confs = {
337
  "dfactor": 8,
338
  },
339
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  "minima_roma": {
341
  "output": "matches-minima_roma",
342
  "model": {
 
337
  "dfactor": 8,
338
  },
339
  },
340
+ "rdd_dense": {
341
+ "output": "matches-rdd_dense",
342
+ "model": {
343
+ "name": "rdd_dense",
344
+ "model_name": "RDD-v2.pth",
345
+ "max_keypoints": 2000,
346
+ "match_threshold": 0.2,
347
+ },
348
+ "preprocessing": {
349
+ "grayscale": False,
350
+ "force_resize": True,
351
+ "resize_max": 1024,
352
+ "width": 320,
353
+ "height": 240,
354
+ "dfactor": 8,
355
+ },
356
+ },
357
  "minima_roma": {
358
  "output": "matches-minima_roma",
359
  "model": {
imcui/hloc/matchers/rdd_dense.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import yaml
3
+ import torch
4
+ from pathlib import Path
5
+ from ..utils.base_model import BaseModel
6
+ from .. import logger, MODEL_REPO_ID, DEVICE
7
+
8
+ rdd_path = Path(__file__).parent / "../../third_party/rdd"
9
+ sys.path.append(str(rdd_path))
10
+
11
+ from RDD.RDD import build as build_rdd
12
+ from RDD.RDD_helper import RDD_helper
13
+
14
+ class RddDense(BaseModel):
15
+ default_conf = {
16
+ "keypoint_threshold": 0.1,
17
+ "max_keypoints": 4096,
18
+ "model_name": "RDD-v2.pth",
19
+ "match_threshold": 0.1,
20
+ }
21
+
22
+ required_inputs = ["image0", "image1"]
23
+
24
+ def _init(self, conf):
25
+ logger.info("Loading RDD model...")
26
+ model_path = self._download_model(
27
+ repo_id=MODEL_REPO_ID,
28
+ filename="{}/{}".format(
29
+ "rdd", self.conf["model_name"]
30
+ ),
31
+ )
32
+ config_path = rdd_path / "configs/default.yaml"
33
+ with open(config_path, "r") as file:
34
+ config = yaml.safe_load(file)
35
+ config["top_k"] = conf["max_keypoints"]
36
+ config["detection_threshold"] = conf["keypoint_threshold"]
37
+ config["device"] = DEVICE
38
+ rdd_net = build_rdd(config=config, weights=model_path)
39
+ rdd_net.eval()
40
+ self.net = RDD_helper(rdd_net)
41
+ logger.info("Loading RDD model done!")
42
+
43
+ def _forward(self, data):
44
+ img0 = data["image0"]
45
+ img1 = data["image1"]
46
+ mkpts_0, mkpts_1, conf = self.net.match_dense(img0, img1, thr=self.conf["match_threshold"])
47
+ pred = {
48
+ "keypoints0": torch.from_numpy(mkpts_0),
49
+ "keypoints1": torch.from_numpy(mkpts_1),
50
+ "mconf": torch.from_numpy(conf),
51
+ }
52
+ return pred
imcui/third_party/rdd/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ .venv
2
+ /build/
3
+ **.egg-info
4
+ **.pyc
5
+ /.idea/
6
+ **/__pycache__/
7
+ weights/
8
+ outputs
imcui/third_party/rdd/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
imcui/third_party/rdd/RDD/RDD.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description: RDD model
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+ from .utils import NestedTensor, nested_tensor_from_tensor_list, to_pixel_coords, read_config
7
+ from .models.detector import build_detector
8
+ from .models.descriptor import build_descriptor
9
+ from .models.soft_detect import SoftDetect
10
+ from .models.interpolator import InterpolateSparse2d
11
+
12
+ class RDD(nn.Module):
13
+
14
+ def __init__(self, detector, descriptor, detection_threshold=0.5, top_k=4096, train_detector=False, device='cuda'):
15
+ super().__init__()
16
+ self.detector = detector
17
+ self.descriptor = descriptor
18
+ self.interpolator = InterpolateSparse2d('bicubic')
19
+ self.detection_threshold = detection_threshold
20
+ self.top_k = top_k
21
+ self.device = device
22
+ if train_detector:
23
+ for p in self.detector.parameters():
24
+ p.requires_grad = True
25
+ for p in self.descriptor.parameters():
26
+ p.requires_grad = False
27
+ else:
28
+ for p in self.detector.parameters():
29
+ p.requires_grad = False
30
+ for p in self.descriptor.parameters():
31
+ p.requires_grad = True
32
+
33
+ self.softdetect = None
34
+ self.stride = descriptor.stride
35
+
36
+ def train(self, mode=True):
37
+ super().train(mode)
38
+ self.set_softdetect(top_k=500, scores_th=0.2)
39
+
40
+ def eval(self):
41
+ super().eval()
42
+ self.set_softdetect(top_k=self.top_k, scores_th=0.01)
43
+
44
+ def forward(self, samples: NestedTensor):
45
+
46
+ if not isinstance(samples, NestedTensor):
47
+ samples = nested_tensor_from_tensor_list(samples)
48
+
49
+ scoremap = self.detector(samples)
50
+
51
+ feats, matchibility = self.descriptor(samples)
52
+
53
+ return feats, scoremap, matchibility
54
+
55
+ def set_softdetect(self, top_k=4096, scores_th=0.01):
56
+ self.softdetect = SoftDetect(radius=2, top_k=top_k, scores_th=scores_th)
57
+
58
+ @torch.inference_mode()
59
+ def filter(self, matchibility):
60
+ # Filter out keypoints on the border
61
+ B, _, H, W = matchibility.shape
62
+ frame = torch.zeros(B, H, W, device=matchibility.device)
63
+ frame[:, self.stride:-self.stride, self.stride:-self.stride] = 1
64
+ matchibility = matchibility * frame
65
+ return matchibility
66
+
67
+ @torch.inference_mode()
68
+ def extract(self, x):
69
+ if self.softdetect is None:
70
+ self.eval()
71
+
72
+ x, rh1, rw1 = self.preprocess_tensor(x)
73
+ x = x.to(self.device).float()
74
+ B, _, _H1, _W1 = x.shape
75
+ M1, K1, H1 = self.forward(x)
76
+ M1 = F.normalize(M1, dim=1)
77
+
78
+ keypoints, kptscores, scoredispersitys = self.softdetect(K1)
79
+
80
+ keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
81
+ kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
82
+
83
+ keypoints = to_pixel_coords(keypoints, _H1, _W1)
84
+
85
+ feats = self.interpolator(M1, keypoints, H = _H1, W = _W1)
86
+
87
+ feats = F.normalize(feats, dim=-1)
88
+
89
+ # Correct kpt scale
90
+ keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1)
91
+ valid = kptscores > self.detection_threshold
92
+
93
+ return [
94
+ {'keypoints': keypoints[b][valid[b]],
95
+ 'scores': kptscores[b][valid[b]],
96
+ 'descriptors': feats[b][valid[b]]} for b in range(B)
97
+ ]
98
+
99
+ @torch.inference_mode()
100
+ def extract_3rd_party(self, x, model='aliked'):
101
+ """
102
+ one image per batch
103
+ """
104
+ x, rh1, rw1 = self.preprocess_tensor(x)
105
+ B, _, _H1, _W1 = x.shape
106
+ if model == 'aliked':
107
+ from third_party import extract_aliked_kpts
108
+ img = x
109
+ mkpts, scores = extract_aliked_kpts(img, self.device)
110
+ else:
111
+ raise ValueError('Unknown model')
112
+
113
+ M1, _ = self.descriptor(x)
114
+ M1 = F.normalize(M1, dim=1)
115
+
116
+ if mkpts.shape[1] > self.top_k:
117
+ idx = torch.argsort(scores, descending=True)[0][:self.top_k]
118
+ mkpts = mkpts[:,idx]
119
+ scores = scores[:,idx]
120
+
121
+ feats = self.interpolator(M1, mkpts, H = _H1, W = _W1)
122
+ feats = F.normalize(feats, dim=-1)
123
+ mkpts = mkpts * torch.tensor([rw1,rh1], device=mkpts.device).view(1, 1, -1)
124
+
125
+ return [
126
+ {'keypoints': mkpts[b],
127
+ 'scores': scores[b],
128
+ 'descriptors': feats[b]} for b in range(B)
129
+ ]
130
+
131
+ @torch.inference_mode()
132
+ def extract_dense(self, x, n_limit=30000, thr=0.01):
133
+ self.set_softdetect(top_k=n_limit, scores_th=-1)
134
+
135
+ x, rh1, rw1 = self.preprocess_tensor(x)
136
+
137
+ B, _, _H1, _W1 = x.shape
138
+
139
+ M1, K1, H1 = self.forward(x)
140
+ M1 = F.normalize(M1, dim=1)
141
+
142
+ keypoints, kptscores, scoredispersitys = self.softdetect(K1)
143
+
144
+ keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
145
+ kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
146
+
147
+ keypoints = to_pixel_coords(keypoints, _H1, _W1)
148
+
149
+ feats = self.interpolator(M1, keypoints, H = _H1, W = _W1)
150
+
151
+ feats = F.normalize(feats, dim=-1)
152
+
153
+ H1 = self.filter(H1)
154
+
155
+ dense_kpts, dense_scores, inds = self.sample_dense_kpts(H1, n_limit=n_limit)
156
+
157
+ dense_keypoints = to_pixel_coords(dense_kpts, _H1, _W1)
158
+
159
+ dense_feats = self.interpolator(M1, dense_keypoints, H = _H1, W = _W1)
160
+
161
+ dense_feats = F.normalize(dense_feats, dim=-1)
162
+
163
+ keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1)
164
+ dense_keypoints = dense_keypoints * torch.tensor([rw1,rh1], device=dense_keypoints.device).view(1, -1)
165
+
166
+ valid = kptscores > self.detection_threshold
167
+ valid_dense = dense_scores > thr
168
+
169
+ return [
170
+ {'keypoints': keypoints[b][valid[b]],
171
+ 'scores': kptscores[b][valid[b]],
172
+ 'descriptors': feats[b][valid[b]],
173
+ 'keypoints_dense': dense_keypoints[b][valid_dense[b]],
174
+ 'scores_dense': dense_scores[b][valid_dense[b]],
175
+ 'descriptors_dense': dense_feats[b][valid_dense[b]]} for b in range(B)
176
+ ]
177
+
178
+ @torch.inference_mode()
179
+ def sample_dense_kpts(self, keypoint_logits, threshold=0.01, n_limit=30000, force_kpts = True):
180
+
181
+ B, K, H, W = keypoint_logits.shape
182
+
183
+ if n_limit < 0 or n_limit > H*W:
184
+ n_limit = min(H*W - 1, n_limit)
185
+
186
+ scoremap = keypoint_logits.permute(0,2,3,1)
187
+
188
+ scoremap = scoremap.reshape(B, H, W)
189
+
190
+ frame = torch.zeros(B, H, W, device=keypoint_logits.device)
191
+
192
+ frame[:, 1:-1, 1:-1] = 1
193
+
194
+ scoremap = scoremap * frame
195
+
196
+ scoremap = scoremap.reshape(B, H*W)
197
+
198
+ grid = self.get_grid(B, H, W, device = keypoint_logits.device)
199
+
200
+ inds = torch.topk(scoremap, n_limit, dim=1).indices
201
+
202
+ # inds = torch.multinomial(scoremap, top_k, replacement=False)
203
+ kpts = torch.gather(grid, 1, inds[..., None].expand(B, n_limit, 2))
204
+ scoremap = torch.gather(scoremap, 1, inds)
205
+ if force_kpts:
206
+ valid = scoremap > threshold
207
+ kpts = kpts[valid][None]
208
+ scoremap = scoremap[valid][None]
209
+
210
+ return kpts, scoremap, inds
211
+
212
+ def preprocess_tensor(self, x):
213
+ """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
214
+ if isinstance(x, np.ndarray) and len(x.shape) == 3:
215
+ x = torch.tensor(x).permute(2,0,1)[None]
216
+ x = x.to(self.device).float()
217
+
218
+ H, W = x.shape[-2:]
219
+
220
+ _H, _W = (H//32) * 32, (W//32) * 32
221
+
222
+ rh, rw = H/_H, W/_W
223
+
224
+ x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
225
+ return x, rh, rw
226
+
227
+ @torch.inference_mode()
228
+ def get_grid(self, B, H, W, device = None):
229
+ x1_n = torch.meshgrid(
230
+ *[
231
+ torch.linspace(
232
+ -1 + 1 / n, 1 - 1 / n, n, device=device
233
+ )
234
+ for n in (B, H, W)
235
+ ]
236
+ )
237
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
238
+ return x1_n
239
+
240
+ def build(config=None, weights=None):
241
+ if config is None:
242
+ config = read_config('./configs/default.yaml')
243
+ if weights is not None:
244
+ config['weights'] = weights
245
+ device = torch.device(config['device'])
246
+ print('config', config)
247
+ detector = build_detector(config)
248
+ descriptor = build_descriptor(config)
249
+ model = RDD(
250
+ detector,
251
+ descriptor,
252
+ detection_threshold=config['detection_threshold'],
253
+ top_k=config['top_k'],
254
+ train_detector=config['train_detector'],
255
+ device=device
256
+ )
257
+ if 'weights' in config and config['weights'] is not None:
258
+ model.load_state_dict(torch.load(config['weights'], map_location='cpu'))
259
+ model.to(device)
260
+ return model
imcui/third_party/rdd/RDD/RDD_helper.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .matchers import DualSoftmaxMatcher, DenseMatcher, LightGlue
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import kornia
7
+
8
+ class RDD_helper(nn.Module):
9
+ def __init__(self, RDD):
10
+ super().__init__()
11
+ self.matcher = DualSoftmaxMatcher(inv_temperature = 20, thr = 0.01)
12
+ self.dense_matcher = DenseMatcher(inv_temperature=20, thr=0.01)
13
+ self.RDD = RDD
14
+ self.lg_matcher = None
15
+
16
+ @torch.inference_mode()
17
+ def match(self, img0, img1, thr=0.01, resize=None, top_k=4096):
18
+ if top_k is not None and top_k != self.RDD.top_k:
19
+ self.RDD.top_k = top_k
20
+ self.RDD.set_softdetect(top_k=top_k)
21
+
22
+ img0, scale0 = self.parse_input(img0, resize)
23
+ img1, scale1 = self.parse_input(img1, resize)
24
+
25
+ out0 = self.RDD.extract(img0)[0]
26
+ out1 = self.RDD.extract(img1)[0]
27
+
28
+ # get top_k confident matches
29
+ mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)
30
+
31
+ scale0 = 1.0 / scale0
32
+ scale1 = 1.0 / scale1
33
+
34
+ mkpts0 = mkpts0 * scale0
35
+ mkpts1 = mkpts1 * scale1
36
+
37
+ return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
38
+
39
+ @torch.inference_mode()
40
+ def match_lg(self, img0, img1, thr=0.01, resize=None, top_k=4096):
41
+ if self.lg_matcher is None:
42
+ lg_conf = {
43
+ "name": "lightglue", # just for interfacing
44
+ "input_dim": 256, # input descriptor dimension (autoselected from weights)
45
+ "descriptor_dim": 256,
46
+ "add_scale_ori": False,
47
+ "n_layers": 9,
48
+ "num_heads": 4,
49
+ "flash": True, # enable FlashAttention if available.
50
+ "mp": False, # enable mixed precision
51
+ "filter_threshold": 0.01, # match threshold
52
+ "depth_confidence": -1, # depth confidence threshold
53
+ "width_confidence": -1, # width confidence threshold
54
+ "weights": './weights/RDD_lg-v2.pth', # path to the weights
55
+ }
56
+ self.lg_matcher = LightGlue(features='rdd', conf=lg_conf).to(self.RDD.device)
57
+
58
+ if top_k is not None and top_k != self.RDD.top_k:
59
+ self.RDD.top_k = top_k
60
+ self.RDD.set_softdetect(top_k=top_k)
61
+
62
+ img0, scale0 = self.parse_input(img0, resize=resize)
63
+ img1, scale1 = self.parse_input(img1, resize=resize)
64
+
65
+ size0 = torch.tensor(img0.shape[-2:])[None]
66
+ size1 = torch.tensor(img1.shape[-2:])[None]
67
+
68
+ out0 = self.RDD.extract(img0)[0]
69
+ out1 = self.RDD.extract(img1)[0]
70
+
71
+ # get top_k confident matches
72
+ image0_data = {
73
+ 'keypoints': out0['keypoints'][None],
74
+ 'descriptors': out0['descriptors'][None],
75
+ 'image_size': size0,
76
+ }
77
+
78
+ image1_data = {
79
+ 'keypoints': out1['keypoints'][None],
80
+ 'descriptors': out1['descriptors'][None],
81
+ 'image_size': size1,
82
+ }
83
+
84
+ pred = {}
85
+
86
+ with torch.no_grad():
87
+ pred.update({'image0': image0_data, 'image1': image1_data})
88
+ pred.update(self.lg_matcher({**pred}))
89
+
90
+ kpts0 = pred['image0']['keypoints'][0]
91
+ kpts1 = pred['image1']['keypoints'][0]
92
+
93
+ matches = pred['matches'][0]
94
+
95
+ mkpts0 = kpts0[matches[... , 0]]
96
+ mkpts1 = kpts1[matches[... , 1]]
97
+ conf = pred['scores'][0]
98
+
99
+ valid_mask = conf > thr
100
+ mkpts0 = mkpts0[valid_mask]
101
+ mkpts1 = mkpts1[valid_mask]
102
+ conf = conf[valid_mask]
103
+
104
+ scale0 = 1.0 / scale0
105
+ scale1 = 1.0 / scale1
106
+ mkpts0 = mkpts0 * scale0
107
+ mkpts1 = mkpts1 * scale1
108
+
109
+ return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
110
+
111
+ @torch.inference_mode()
112
+ def match_dense(self, img0, img1, thr=0.01, resize=None):
113
+
114
+ img0, scale0 = self.parse_input(img0, resize=resize)
115
+ img1, scale1 = self.parse_input(img1, resize=resize)
116
+
117
+ out0 = self.RDD.extract_dense(img0)[0]
118
+ out1 = self.RDD.extract_dense(img1)[0]
119
+
120
+ # get top_k confident matches
121
+ mkpts0, mkpts1, conf = self.dense_matcher(out0, out1, thr, err_thr=self.RDD.stride)
122
+
123
+ scale0 = 1.0 / scale0
124
+ scale1 = 1.0 / scale1
125
+
126
+ mkpts0 = mkpts0 * scale0
127
+ mkpts1 = mkpts1 * scale1
128
+
129
+ return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
130
+
131
+ @torch.inference_mode()
132
+ def match_3rd_party(self, img0, img1, model='aliked', resize=None, thr=0.01):
133
+ img0, scale0 = self.parse_input(img0, resize=resize)
134
+ img1, scale1 = self.parse_input(img1, resize=resize)
135
+
136
+ out0 = self.RDD.extract_3rd_party(img0, model=model)[0]
137
+ out1 = self.RDD.extract_3rd_party(img1, model=model)[0]
138
+
139
+ mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)
140
+
141
+ scale0 = 1.0 / scale0
142
+ scale1 = 1.0 / scale1
143
+
144
+ mkpts0 = mkpts0 * scale0
145
+ mkpts1 = mkpts1 * scale1
146
+
147
+ return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()
148
+
149
+ def parse_input(self, x, resize=None):
150
+ if len(x.shape) == 3:
151
+ x = x[None, ...]
152
+
153
+ if isinstance(x, np.ndarray):
154
+ x = torch.tensor(x).permute(0,3,1,2)/255
155
+
156
+ h, w = x.shape[-2:]
157
+ size = h, w
158
+
159
+ if resize is not None:
160
+ size = self.get_new_image_size(h, w, resize)
161
+ x = kornia.geometry.transform.resize(
162
+ x,
163
+ size,
164
+ side='long',
165
+ antialias=True,
166
+ align_corners=None,
167
+ interpolation='bilinear',
168
+ )
169
+ scale = torch.Tensor([x.shape[-1] / w, x.shape[-2] / h]).to(self.RDD.device)
170
+
171
+ return x, scale
172
+
173
+ def get_new_image_size(self, h, w, resize=1600):
174
+ aspect_ratio = w / h
175
+ size = int(resize / aspect_ratio), resize
176
+
177
+ size = list(map(lambda x: int(x // 32 * 32), size)) # make sure size is divisible by 32
178
+
179
+ return size
imcui/third_party/rdd/RDD/dataset/__init__.py ADDED
File without changes
imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .megadepth import *
2
+ from .megadepth_warper import *
imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import h5py
4
+ import torch
5
+ import pickle
6
+ import numpy as np
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+ from torchvision import transforms
11
+ from torch.utils.data import Dataset
12
+
13
+ import cv2
14
+ from .utils import scale_intrinsics, warp_depth, warp_points2d
15
+
16
+ class MegaDepthDataset(Dataset):
17
+ def __init__(
18
+ self,
19
+ root,
20
+ npz_path,
21
+ num_per_scene=100,
22
+ image_size=256,
23
+ min_overlap_score=0.1,
24
+ max_overlap_score=0.9,
25
+ gray=False,
26
+ crop_or_scale='scale', # crop, scale, crop_scale
27
+ train=True,
28
+ ):
29
+ self.data_path = Path(root)
30
+ self.num_per_scene = num_per_scene
31
+ self.train = train
32
+ self.image_size = image_size
33
+ self.gray = gray
34
+ self.crop_or_scale = crop_or_scale
35
+
36
+ self.scene_info = dict(np.load(npz_path, allow_pickle=True))
37
+ self.pair_infos = self.scene_info['pair_infos'].copy()
38
+ del self.scene_info['pair_infos']
39
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score]
40
+ if len(self.pair_infos) > num_per_scene:
41
+ indices = np.random.choice(len(self.pair_infos), num_per_scene, replace=False)
42
+ self.pair_infos = [self.pair_infos[idx] for idx in indices]
43
+ self.transforms = transforms.Compose([transforms.ToPILImage(),
44
+ transforms.ToTensor()])
45
+
46
+ def __len__(self):
47
+ return len(self.pair_infos)
48
+
49
+ def recover_pair(self, idx):
50
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)]
51
+
52
+ img_name1 = self.scene_info['image_paths'][idx0]
53
+ img_name2 = self.scene_info['image_paths'][idx1]
54
+
55
+ depth1 = '/'.join([self.scene_info['depth_paths'][idx0].replace('phoenix/S6/zl548/MegaDepth_v1', 'depth_undistorted').split('/')[i] for i in [0, 1, -1]])
56
+ depth2 = '/'.join([self.scene_info['depth_paths'][idx1].replace('phoenix/S6/zl548/MegaDepth_v1', 'depth_undistorted').split('/')[i] for i in [0, 1, -1]])
57
+
58
+ depth_path1 = self.data_path / depth1
59
+ with h5py.File(depth_path1, 'r') as hdf5_file:
60
+ depth1 = np.array(hdf5_file['/depth'])
61
+ assert (np.min(depth1) >= 0)
62
+ image_path1 = self.data_path / img_name1
63
+ image1 = Image.open(image_path1)
64
+ if image1.mode != 'RGB':
65
+ image1 = image1.convert('RGB')
66
+ image1 = np.array(image1)
67
+ assert (image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1])
68
+ intrinsics1 = self.scene_info['intrinsics'][idx0].copy()
69
+ pose1 = self.scene_info['poses'][idx0]
70
+
71
+ depth_path2 = self.data_path / depth2
72
+ with h5py.File(depth_path2, 'r') as hdf5_file:
73
+ depth2 = np.array(hdf5_file['/depth'])
74
+ assert (np.min(depth2) >= 0)
75
+ image_path2 = self.data_path / img_name2
76
+ image2 = Image.open(image_path2)
77
+ if image2.mode != 'RGB':
78
+ image2 = image2.convert('RGB')
79
+ image2 = np.array(image2)
80
+ assert (image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1])
81
+ intrinsics2 = self.scene_info['intrinsics'][idx1].copy()
82
+ pose2 = self.scene_info['poses'][idx1]
83
+
84
+ pose12 = pose2 @ np.linalg.inv(pose1)
85
+ pose21 = np.linalg.inv(pose12)
86
+
87
+ if self.train:
88
+ if "crop" in self.crop_or_scale:
89
+ # ================================================= compute central_match
90
+ DOWNSAMPLE = 10
91
+ # resize to speed up
92
+ depth1s = cv2.resize(depth1, (depth1.shape[1] // DOWNSAMPLE, depth1.shape[0] // DOWNSAMPLE))
93
+ depth2s = cv2.resize(depth2, (depth2.shape[1] // DOWNSAMPLE, depth2.shape[0] // DOWNSAMPLE))
94
+ intrinsic1s = scale_intrinsics(intrinsics1, (DOWNSAMPLE, DOWNSAMPLE))
95
+ intrinsic2s = scale_intrinsics(intrinsics2, (DOWNSAMPLE, DOWNSAMPLE))
96
+
97
+ # warp
98
+ depth12s = warp_depth(depth1s, intrinsic1s, intrinsic2s, pose12, depth2s.shape)
99
+ depth21s = warp_depth(depth2s, intrinsic2s, intrinsic1s, pose21, depth1s.shape)
100
+
101
+ depth12s[depth12s < 0] = 0
102
+ depth21s[depth21s < 0] = 0
103
+
104
+ valid12s = np.logical_and(depth12s > 0, depth2s > 0)
105
+ valid21s = np.logical_and(depth21s > 0, depth1s > 0)
106
+
107
+ pos1 = np.array(valid21s.nonzero())
108
+ try:
109
+ idx1_random = np.random.choice(np.arange(pos1.shape[1]), 1)
110
+ uv1s = pos1[:, idx1_random][[1, 0]].reshape(1, 2)
111
+ d1s = np.array(depth1s[uv1s[0, 1], uv1s[0, 0]]).reshape(1, 1)
112
+
113
+ uv12s, z12s = warp_points2d(uv1s, d1s, intrinsic1s, intrinsic2s, pose12)
114
+
115
+ uv1 = uv1s[0] * DOWNSAMPLE
116
+ uv2 = uv12s[0] * DOWNSAMPLE
117
+ except ValueError:
118
+ uv1 = [depth1.shape[1] / 2, depth1.shape[0] / 2]
119
+ uv2 = [depth2.shape[1] / 2, depth2.shape[0] / 2]
120
+
121
+ central_match = [uv1[1], uv1[0], uv2[1], uv2[0]]
122
+ # ================================================= compute central_match
123
+
124
+ if self.crop_or_scale == 'crop':
125
+ # =============== padding
126
+ h1, w1, _ = image1.shape
127
+ h2, w2, _ = image2.shape
128
+ if h1 < self.image_size:
129
+ padding = np.zeros((self.image_size - h1, w1, 3))
130
+ image1 = np.concatenate([image1, padding], axis=0).astype(np.uint8)
131
+ depth1 = np.concatenate([depth1, padding[:, :, 0]], axis=0).astype(np.float32)
132
+ h1, w1, _ = image1.shape
133
+ if w1 < self.image_size:
134
+ padding = np.zeros((h1, self.image_size - w1, 3))
135
+ image1 = np.concatenate([image1, padding], axis=1).astype(np.uint8)
136
+ depth1 = np.concatenate([depth1, padding[:, :, 0]], axis=1).astype(np.float32)
137
+ if h2 < self.image_size:
138
+ padding = np.zeros((self.image_size - h2, w2, 3))
139
+ image2 = np.concatenate([image2, padding], axis=0).astype(np.uint8)
140
+ depth2 = np.concatenate([depth2, padding[:, :, 0]], axis=0).astype(np.float32)
141
+ h2, w2, _ = image2.shape
142
+ if w2 < self.image_size:
143
+ padding = np.zeros((h2, self.image_size - w2, 3))
144
+ image2 = np.concatenate([image2, padding], axis=1).astype(np.uint8)
145
+ depth2 = np.concatenate([depth2, padding[:, :, 0]], axis=1).astype(np.float32)
146
+ # =============== padding
147
+ image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)
148
+
149
+ depth1 = depth1[bbox1[0]: bbox1[0] + self.image_size, bbox1[1]: bbox1[1] + self.image_size]
150
+ depth2 = depth2[bbox2[0]: bbox2[0] + self.image_size, bbox2[1]: bbox2[1] + self.image_size]
151
+ elif self.crop_or_scale == 'scale':
152
+ image1, depth1, intrinsics1 = self.scale(image1, depth1, intrinsics1)
153
+ image2, depth2, intrinsics2 = self.scale(image2, depth2, intrinsics2)
154
+ bbox1 = bbox2 = np.array([0., 0.])
155
+ elif self.crop_or_scale == 'crop_scale':
156
+ bbox1 = bbox2 = np.array([0., 0.])
157
+ image1, depth1, intrinsics1 = self.crop_scale(image1, depth1, intrinsics1, central_match[:2])
158
+ image2, depth2, intrinsics2 = self.crop_scale(image2, depth2, intrinsics2, central_match[2:])
159
+ else:
160
+ raise RuntimeError(f"Unkown type {self.crop_or_scale}")
161
+ else:
162
+ bbox1 = bbox2 = np.array([0., 0.])
163
+
164
+ return (image1, depth1, intrinsics1, pose12, bbox1,
165
+ image2, depth2, intrinsics2, pose21, bbox2)
166
+
167
+ def scale(self, image, depth, intrinsic):
168
+ img_size_org = image.shape
169
+ image = cv2.resize(image, (self.image_size, self.image_size))
170
+ depth = cv2.resize(depth, (self.image_size, self.image_size))
171
+ intrinsic = scale_intrinsics(intrinsic, (img_size_org[1] / self.image_size, img_size_org[0] / self.image_size))
172
+ return image, depth, intrinsic
173
+
174
+ def crop_scale(self, image, depth, intrinsic, centeral):
175
+ h_org, w_org, three = image.shape
176
+ image_size = min(h_org, w_org)
177
+ if h_org > w_org:
178
+ if centeral[1] - image_size // 2 < 0:
179
+ h_start = 0
180
+ elif centeral[1] + image_size // 2 > h_org:
181
+ h_start = h_org - image_size
182
+ else:
183
+ h_start = int(centeral[1]) - image_size // 2
184
+ w_start = 0
185
+ else:
186
+ if centeral[0] - image_size // 2 < 0:
187
+ w_start = 0
188
+ elif centeral[0] + image_size // 2 > w_org:
189
+ w_start = w_org - image_size
190
+ else:
191
+ w_start = int(centeral[0]) - image_size // 2
192
+ h_start = 0
193
+
194
+ croped_image = image[h_start: h_start + image_size, w_start: w_start + image_size]
195
+ croped_depth = depth[h_start: h_start + image_size, w_start: w_start + image_size]
196
+ intrinsic[0, 2] = intrinsic[0, 2] - w_start
197
+ intrinsic[1, 2] = intrinsic[1, 2] - h_start
198
+
199
+ image = cv2.resize(croped_image, (self.image_size, self.image_size))
200
+ depth = cv2.resize(croped_depth, (self.image_size, self.image_size))
201
+ intrinsic = scale_intrinsics(intrinsic, (image_size / self.image_size, image_size / self.image_size))
202
+
203
+ return image, depth, intrinsic
204
+
205
+ def crop(self, image1, image2, central_match):
206
+ bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0)
207
+ if bbox1_i + self.image_size >= image1.shape[0]:
208
+ bbox1_i = image1.shape[0] - self.image_size
209
+ bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0)
210
+ if bbox1_j + self.image_size >= image1.shape[1]:
211
+ bbox1_j = image1.shape[1] - self.image_size
212
+
213
+ bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0)
214
+ if bbox2_i + self.image_size >= image2.shape[0]:
215
+ bbox2_i = image2.shape[0] - self.image_size
216
+ bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0)
217
+ if bbox2_j + self.image_size >= image2.shape[1]:
218
+ bbox2_j = image2.shape[1] - self.image_size
219
+
220
+ return (image1[bbox1_i: bbox1_i + self.image_size, bbox1_j: bbox1_j + self.image_size],
221
+ np.array([bbox1_i, bbox1_j]),
222
+ image2[bbox2_i: bbox2_i + self.image_size, bbox2_j: bbox2_j + self.image_size],
223
+ np.array([bbox2_i, bbox2_j])
224
+ )
225
+
226
+ def __getitem__(self, idx):
227
+ (image1, depth1, intrinsics1, pose12, bbox1,
228
+ image2, depth2, intrinsics2, pose21, bbox2) \
229
+ = self.recover_pair(idx)
230
+
231
+ if self.gray:
232
+ gray1 = cv2.cvtColor(image1, cv2.COLOR_RGB2GRAY)
233
+ gray2 = cv2.cvtColor(image2, cv2.COLOR_RGB2GRAY)
234
+ gray1 = transforms.ToTensor()(gray1)
235
+ gray2 = transforms.ToTensor()(gray2)
236
+ if self.transforms is not None:
237
+ image1, image2 = self.transforms(image1), self.transforms(image2) # [C,H,W]
238
+ ret = {'image0': image1,
239
+ 'image1': image2,
240
+ 'angle': 0,
241
+ 'overlap': self.pair_infos[idx][1],
242
+ 'warp01_params': {'mode': 'se3',
243
+ 'width': self.image_size if self.train else image1.shape[2],
244
+ 'height': self.image_size if self.train else image1.shape[1],
245
+ 'pose01': torch.from_numpy(pose12.astype(np.float32)),
246
+ 'bbox0': torch.from_numpy(bbox1.astype(np.float32)),
247
+ 'bbox1': torch.from_numpy(bbox2.astype(np.float32)),
248
+ 'depth0': torch.from_numpy(depth1.astype(np.float32)),
249
+ 'depth1': torch.from_numpy(depth2.astype(np.float32)),
250
+ 'intrinsics0': torch.from_numpy(intrinsics1.astype(np.float32)),
251
+ 'intrinsics1': torch.from_numpy(intrinsics2.astype(np.float32))},
252
+ 'warp10_params': {'mode': 'se3',
253
+ 'width': self.image_size if self.train else image2.shape[2],
254
+ 'height': self.image_size if self.train else image2.shape[2],
255
+ 'pose01': torch.from_numpy(pose21.astype(np.float32)),
256
+ 'bbox0': torch.from_numpy(bbox2.astype(np.float32)),
257
+ 'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
258
+ 'depth0': torch.from_numpy(depth2.astype(np.float32)),
259
+ 'depth1': torch.from_numpy(depth1.astype(np.float32)),
260
+ 'intrinsics0': torch.from_numpy(intrinsics2.astype(np.float32)),
261
+ 'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32))},
262
+ }
263
+ if self.gray:
264
+ ret['gray0'] = gray1
265
+ ret['gray1'] = gray2
266
+ return ret
267
+
268
+
269
+ if __name__ == '__main__':
270
+ from torch.utils.data import DataLoader
271
+ import matplotlib.pyplot as plt
272
+
273
+
274
+ def visualize(image0, image1, depth0, depth1):
275
+ # visualize image and depth
276
+ plt.figure(figsize=(9, 9))
277
+ plt.subplot(2, 2, 1)
278
+ plt.imshow(image0, cmap='gray')
279
+ plt.subplot(2, 2, 2)
280
+ plt.imshow(depth0)
281
+ plt.subplot(2, 2, 3)
282
+ plt.imshow(image1, cmap='gray')
283
+ plt.subplot(2, 2, 4)
284
+ plt.imshow(depth1)
285
+ plt.show()
286
+
287
+
288
+ dataset = MegaDepthDataset( # root='../data/megadepth',
289
+ root='../data/imw2020val',
290
+ train=False,
291
+ using_cache=True,
292
+ pairs_per_scene=100,
293
+ image_size=256,
294
+ colorjit=True,
295
+ gray=False,
296
+ crop_or_scale='scale',
297
+ )
298
+ dataset.build_dataset()
299
+
300
+ batch_size = 2
301
+
302
+ loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
303
+
304
+ for idx, batch in enumerate(tqdm(loader)):
305
+ image0, image1 = batch['image0'], batch['image1'] # [B,3,H,W]
306
+ depth0, depth1 = batch['warp01_params']['depth0'], batch['warp01_params']['depth1'] # [B,H,W]
307
+ intrinsics0, intrinsics1 = batch['warp01_params']['intrinsics0'], batch['warp01_params'][
308
+ 'intrinsics1'] # [B,3,3]
309
+
310
+ batch_size, channels, h, w = image0.shape
311
+
312
+ for b_idx in range(batch_size):
313
+ visualize(image0[b_idx].permute(1, 2, 0), image1[b_idx].permute(1, 2, 0), depth0[b_idx], depth1[b_idx])
imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from kornia.utils import create_meshgrid
3
+ import matplotlib.pyplot as plt
4
+ import pdb
5
+ from .utils import warp
6
+
7
+ @torch.no_grad()
8
+ def spvs_coarse(data, scale = 8):
9
+ N, _, H0, W0 = data['image0'].shape
10
+ _, _, H1, W1 = data['image1'].shape
11
+ device = data['image0'].device
12
+ corrs = []
13
+ for idx in range(N):
14
+ warp01_params = {}
15
+ for k, v in data['warp01_params'].items():
16
+ if isinstance(v[idx], torch.Tensor):
17
+ warp01_params[k] = v[idx].to(device)
18
+ else:
19
+ warp01_params[k] = v[idx]
20
+ warp10_params = {}
21
+ for k, v in data['warp10_params'].items():
22
+ if isinstance(v[idx], torch.Tensor):
23
+ warp10_params[k] = v[idx].to(device)
24
+ else:
25
+ warp10_params[k] = v[idx]
26
+
27
+ # create kpts
28
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
29
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(h1*w1, 2) # [N, hw, 2]
30
+
31
+ # normalize kpts
32
+ grid_pt1_c = grid_pt1_c * scale
33
+
34
+
35
+
36
+ try:
37
+ grid_pt1_c_valid, grid_pt10_c, ids1, ids1_out = warp(grid_pt1_c, warp10_params)
38
+ grid_pt10_c_valid, grid_pt01_c, ids0, ids0_out = warp(grid_pt10_c, warp01_params)
39
+
40
+ # check reproj error
41
+ grid_pt1_c_valid = grid_pt1_c_valid[ids0]
42
+ dist = torch.linalg.norm(grid_pt1_c_valid - grid_pt01_c, dim=-1)
43
+
44
+ mask_mutual = (dist < 1.5)
45
+
46
+ #get correspondences
47
+ pts = torch.cat([grid_pt10_c_valid[mask_mutual] / scale,
48
+ grid_pt01_c[mask_mutual] / scale], dim=-1)
49
+ #remove repeated correspondences
50
+ lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1
51
+ lut_mat21 = torch.clone(lut_mat12)
52
+ src_pts = pts[:, :2]
53
+ tgt_pts = pts[:, 2:]
54
+
55
+ lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
56
+ mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1)
57
+ points = lut_mat12[mask_valid12]
58
+
59
+ #Target-src check
60
+ src_pts, tgt_pts = points[:, :2], points[:, 2:]
61
+ lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
62
+ mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1)
63
+ points = lut_mat21[mask_valid21]
64
+
65
+ corrs.append(points)
66
+ except:
67
+ corrs.append(torch.zeros((0, 4), device = device))
68
+ #pdb.set_trace()
69
+ #print('..')
70
+
71
+ #Plot for debug purposes
72
+ # for i in range(len(corrs)):
73
+ # plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8)
74
+
75
+ return corrs
imcui/third_party/rdd/RDD/dataset/megadepth/utils.py ADDED
@@ -0,0 +1,848 @@
1
+ """
2
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
3
+ https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
4
+
5
+ MegaDepth data handling was adapted from
6
+ LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
7
+ """
8
+
9
+ import io
10
+ import cv2
11
+ import numpy as np
12
+ import h5py
13
+ import torch
14
+ from numpy.linalg import inv
15
+ from kornia.geometry.epipolar import essential_from_Rt
16
+ from kornia.geometry.epipolar import fundamental_from_essential
17
+
18
+ import cv2
19
+ import torch
20
+ import numpy as np
21
+ from numba import jit
22
+ from copy import deepcopy
23
+
24
+ try:
25
+ from utils.project_depth_nn_cython_pkg import project_depth_nn_cython
26
+
27
+ nn_cython = True
28
+ except:
29
+ print('\033[1;41;37mWarning: using python to project depth!!!\033[0m')
30
+
31
+ nn_cython = False
32
+
33
+
34
+ class EmptyTensorError(Exception):
35
+ pass
36
+
37
+
38
+ def mutual_NN(cross_matrix, mode: str = 'min'):
39
+ """
40
+ compute mutual nearest neighbor from a cross_matrix, non-differentiable function
41
+ :param cross_matrix: N0xN1
42
+ :param mode: 'min': mutual minimum; 'max':mutual maximum
43
+ :return: index0,index1, Mx2
44
+ """
45
+ if mode == 'min':
46
+ nn0 = cross_matrix == cross_matrix.min(dim=1, keepdim=True)[0]
47
+ nn1 = cross_matrix == cross_matrix.min(dim=0, keepdim=True)[0]
48
+ elif mode == 'max':
49
+ nn0 = cross_matrix == cross_matrix.max(dim=1, keepdim=True)[0]
50
+ nn1 = cross_matrix == cross_matrix.max(dim=0, keepdim=True)[0]
51
+ else:
52
+ raise TypeError("mode must be 'min' or 'max'.")
53
+
54
+ mutual_nn = nn0 * nn1
55
+
56
+ return torch.nonzero(mutual_nn, as_tuple=False)
57
+
58
+
59
+ def mutual_argmax(value, mask=None, as_tuple=True):
60
+ """
61
+ Args:
62
+ value: MxN
63
+ mask: MxN
64
+
65
+ Returns:
66
+
67
+ """
68
+ value = value - value.min() # convert to non-negative tensor
69
+ if mask is not None:
70
+ value = value * mask
71
+
72
+ max0 = value.max(dim=1, keepdim=True) # the col index the max value in each row
73
+ max1 = value.max(dim=0, keepdim=True)
74
+
75
+ valid_max0 = value == max0[0]
76
+ valid_max1 = value == max1[0]
77
+
78
+ mutual = valid_max0 * valid_max1
79
+ if mask is not None:
80
+ mutual = mutual * mask
81
+
82
+ return mutual.nonzero(as_tuple=as_tuple)
83
+
84
+
85
+ def mutual_argmin(value, mask=None):
86
+ return mutual_argmax(-value, mask)
87
+
88
+
89
+ def compute_keypoints_distance(kpts0, kpts1, p=2):
90
+ """
91
+ Args:
92
+ kpts0: torch.tensor [M,2]
93
+ kpts1: torch.tensor [N,2]
94
+ p: (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm
95
+
96
+ Returns:
97
+ dist, torch.tensor [N,M]
98
+ """
99
+ dist = kpts0[:, None, :] - kpts1[None, :, :] # [M,N,2]
100
+ dist = torch.norm(dist, p=p, dim=2) # [M,N]
101
+ return dist
102
+
103
+
104
+ def keypoints_normal2pixel(kpts_normal, w, h):
105
+ wh = kpts_normal[0].new_tensor([[w - 1, h - 1]])
106
+ kpts_pixel = [(kpts + 1) / 2 * wh for kpts in kpts_normal]
107
+ return kpts_pixel
108
+
109
+
110
+ def plot_keypoints(image, kpts, radius=2, color=(255, 0, 0)):
111
+ image = image.cpu().detach().numpy() if isinstance(image, torch.Tensor) else image
112
+ kpts = kpts.cpu().detach().numpy() if isinstance(kpts, torch.Tensor) else kpts
113
+
114
+ if image.dtype is not np.dtype('uint8'):
115
+ image = image * 255
116
+ image = image.astype(np.uint8)
117
+
118
+ if len(image.shape) == 2 or image.shape[2] == 1:
119
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
120
+
121
+ out = np.ascontiguousarray(deepcopy(image))
122
+ kpts = np.round(kpts).astype(int)
123
+
124
+ for kpt in kpts:
125
+ y0, x0 = kpt
126
+ cv2.drawMarker(out, (x0, y0), color, cv2.MARKER_CROSS, radius)
127
+
128
+ # cv2.circle(out, (x0, y0), radius, color, -1, lineType=cv2.LINE_4)
129
+ return out
130
+
131
+
132
+ def plot_matches(image0, image1, kpts0, kpts1, radius=2, color=(255, 0, 0), mcolor=(0, 255, 0), layout='lr'):
133
+ image0 = image0.cpu().detach().numpy() if isinstance(image0, torch.Tensor) else image0
134
+ image1 = image1.cpu().detach().numpy() if isinstance(image1, torch.Tensor) else image1
135
+ kpts0 = kpts0.cpu().detach().numpy() if isinstance(kpts0, torch.Tensor) else kpts0
136
+ kpts1 = kpts1.cpu().detach().numpy() if isinstance(kpts1, torch.Tensor) else kpts1
137
+
138
+ out0 = plot_keypoints(image0, kpts0, radius, color)
139
+ out1 = plot_keypoints(image1, kpts1, radius, color)
140
+
141
+ H0, W0 = image0.shape[0], image0.shape[1]
142
+ H1, W1 = image1.shape[0], image1.shape[1]
143
+
144
+ if layout == "lr":
145
+ H, W = max(H0, H1), W0 + W1
146
+ out = 255 * np.ones((H, W, 3), np.uint8)
147
+ out[:H0, :W0, :] = out0
148
+ out[:H1, W0:, :] = out1
149
+ elif layout == "ud":
150
+ H, W = H0 + H1, max(W0, W1)
151
+ out = 255 * np.ones((H, W, 3), np.uint8)
152
+ out[:H0, :W0, :] = out0
153
+ out[H0:, :W1, :] = out1
154
+ else:
155
+ raise ValueError("The layout must be 'lr' or 'ud'!")
156
+
157
+ kpts0 = np.round(kpts0).astype(int)
158
+ kpts1 = np.round(kpts1).astype(int)
159
+
160
+ for kpt0, kpt1 in zip(kpts0, kpts1):
161
+ (y0, x0), (y1, x1) = kpt0, kpt1
162
+
163
+ if layout == "lr":
164
+ cv2.line(out, (x0, y0), (x1 + W0, y1), color=mcolor, thickness=1, lineType=cv2.LINE_AA)
165
+ elif layout == "ud":
166
+ cv2.line(out, (x0, y0), (x1, y1 + H0), color=mcolor, thickness=1, lineType=cv2.LINE_AA)
167
+
168
+ return out
169
+
170
+
171
+ def interpolate_depth(pos, depth):
172
+ pos = pos.t()[[1, 0]] # Nx2 -> 2xN; w,h -> h,w(i,j)
173
+
174
+ # =============================================== from d2-net
175
+ device = pos.device
176
+
177
+ ids = torch.arange(0, pos.size(1), device=device)
178
+
179
+ h, w = depth.size()
180
+
181
+ i = pos[0, :].detach() # TODO: changed here
182
+ j = pos[1, :].detach() # TODO: changed here
183
+
184
+ # Valid corners
185
+ i_top_left = torch.floor(i).long()
186
+ j_top_left = torch.floor(j).long()
187
+ valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)
188
+
189
+ i_top_right = torch.floor(i).long()
190
+ j_top_right = torch.ceil(j).long()
191
+ valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)
192
+
193
+ i_bottom_left = torch.ceil(i).long()
194
+ j_bottom_left = torch.floor(j).long()
195
+ valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)
196
+
197
+ i_bottom_right = torch.ceil(i).long()
198
+ j_bottom_right = torch.ceil(j).long()
199
+ valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)
200
+
201
+ valid_corners = torch.min(torch.min(valid_top_left, valid_top_right),
202
+ torch.min(valid_bottom_left, valid_bottom_right))
203
+
204
+ i_top_left = i_top_left[valid_corners]
205
+ j_top_left = j_top_left[valid_corners]
206
+
207
+ i_top_right = i_top_right[valid_corners]
208
+ j_top_right = j_top_right[valid_corners]
209
+
210
+ i_bottom_left = i_bottom_left[valid_corners]
211
+ j_bottom_left = j_bottom_left[valid_corners]
212
+
213
+ i_bottom_right = i_bottom_right[valid_corners]
214
+ j_bottom_right = j_bottom_right[valid_corners]
215
+
216
+ ids = ids[valid_corners]
217
+ ids_valid_corners = deepcopy(ids)
218
+ if ids.size(0) == 0:
219
+ # raise ValueError('empty tensor: ids')
220
+ raise EmptyTensorError
221
+
222
+ # Valid depth
223
+ valid_depth = torch.min(torch.min(depth[i_top_left, j_top_left] > 0,
224
+ depth[i_top_right, j_top_right] > 0),
225
+ torch.min(depth[i_bottom_left, j_bottom_left] > 0,
226
+ depth[i_bottom_right, j_bottom_right] > 0))
227
+
228
+ i_top_left = i_top_left[valid_depth]
229
+ j_top_left = j_top_left[valid_depth]
230
+
231
+ i_top_right = i_top_right[valid_depth]
232
+ j_top_right = j_top_right[valid_depth]
233
+
234
+ i_bottom_left = i_bottom_left[valid_depth]
235
+ j_bottom_left = j_bottom_left[valid_depth]
236
+
237
+ i_bottom_right = i_bottom_right[valid_depth]
238
+ j_bottom_right = j_bottom_right[valid_depth]
239
+
240
+ ids = ids[valid_depth]
241
+ ids_valid_depth = deepcopy(ids)
242
+ if ids.size(0) == 0:
243
+ # raise ValueError('empty tensor: ids')
244
+ raise EmptyTensorError
245
+
246
+ # Interpolation
247
+ i = i[ids]
248
+ j = j[ids]
249
+ dist_i_top_left = i - i_top_left.float()
250
+ dist_j_top_left = j - j_top_left.float()
251
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
252
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
253
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
254
+ w_bottom_right = dist_i_top_left * dist_j_top_left
255
+
256
+ interpolated_depth = (w_top_left * depth[i_top_left, j_top_left] +
257
+ w_top_right * depth[i_top_right, j_top_right] +
258
+ w_bottom_left * depth[i_bottom_left, j_bottom_left] +
259
+ w_bottom_right * depth[i_bottom_right, j_bottom_right])
260
+
261
+ # pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)
262
+
263
+ pos = pos[:, ids]
264
+
265
+ # =============================================== from d2-net
266
+ pos = pos[[1, 0]].t() # 2xN -> Nx2; h,w(i,j) -> w,h
267
+
268
+ # interpolated_depth: valid interpolated depth
269
+ # pos: valid position (keypoint)
270
+ # ids: indices of valid position (keypoint)
271
+
272
+ return [interpolated_depth, pos, ids, ids_valid_corners, ids_valid_depth]
273
+
274
+
275
+ def to_homogeneous(kpts):
276
+ '''
277
+ :param kpts: Nx2
278
+ :return: Nx3
279
+ '''
280
+ ones = kpts.new_ones([kpts.shape[0], 1])
281
+ return torch.cat((kpts, ones), dim=1)
282
+
283
+
284
+ def warp_homography(kpts0, params):
285
+ '''
286
+ :param kpts: Nx2
287
+ :param homography_matrix: 3x3
288
+ :return:
289
+ '''
290
+ homography_matrix = params['homography_matrix']
291
+ w, h = params['width'], params['height']
292
+ kpts0_homogeneous = to_homogeneous(kpts0)
293
+ kpts01_homogeneous = torch.einsum('ij,kj->ki', homography_matrix, kpts0_homogeneous)
294
+ kpts01 = kpts01_homogeneous[:, :2] / kpts01_homogeneous[:, 2:]
295
+
296
+ kpts01_ = kpts01.detach()
297
+ # due to float coordinates, the upper boundary should be (w-1) and (h-1).
298
+ # For example, if the image size is 480, then the coordinates should be in [0~479].
299
+ # 479.5 is not acceptable.
300
+ valid01 = (kpts01_[:, 0] >= 0) * (kpts01_[:, 0] <= w - 1) * (kpts01_[:, 1] >= 0) * (kpts01_[:, 1] <= h - 1)
301
+ kpts0_valid = kpts0[valid01]
302
+ kpts01_valid = kpts01[valid01]
303
+ ids = torch.nonzero(valid01, as_tuple=False)[:, 0]
304
+ ids_out = torch.nonzero(~valid01, as_tuple=False)[:, 0]
305
+
306
+ # kpts0_valid: valid keypoints0, the invalid and inconsistent keypoints are removed
307
+ # kpts01_valid: the warped valid keypoints0
308
+ # ids: the valid indices
309
+ return kpts0_valid, kpts01_valid, ids, ids_out
310
+
311
+
312
+ def project(points3d, K):
313
+ """
314
+ project 3D points to image plane
315
+
316
+ Args:
317
+ points3d: [N,3]
318
+ K: [3,3]
319
+
320
+ Returns:
321
+ uv, (u,v), [N,2]
322
+ """
323
+ if type(K) == torch.Tensor:
324
+ zuv1 = torch.einsum('jk,nk->nj', K, points3d) # z*(u,v,1) = K*points3d -> [N,3]
325
+ elif type(K) == np.ndarray:
326
+ zuv1 = np.einsum('jk,nk->nj', K, points3d)
327
+ else:
328
+ raise TypeError("Input type should be 'torch.tensor' or 'numpy.ndarray'")
329
+ uv1 = zuv1 / zuv1[:, -1][:, None] # (u,v,1) -> [N,3]
330
+ uv = uv1[:, 0:2] # (u,v) -> [N,2]
331
+ return uv, zuv1[:, -1]
332
+
333
+
334
+ def unproject(uv, d, K):
335
+ """
336
+ unproject pixels uv to 3D points
337
+
338
+ Args:
339
+ uv: [N,2]
340
+ d: depth, [N,1]
341
+ K: [3,3]
342
+
343
+ Returns:
344
+ 3D points, [N,3]
345
+ """
346
+ duv = uv * d # (u,v) [N,2]
347
+ if type(K) == torch.Tensor:
348
+ duv1 = torch.cat([duv, d], dim=1) # z*(u,v,1) [N,3]
349
+ K_inv = torch.inverse(K) # [3,3]
350
+ points3d = torch.einsum('jk,nk->nj', K_inv, duv1) # [N,3]
351
+ elif type(K) == np.ndarray:
352
+ duv1 = np.concatenate((duv, d), axis=1) # z*(u,v,1) [N,3]
353
+ K_inv = np.linalg.inv(K) # [3,3]
354
+ points3d = np.einsum('jk,nk->nj', K_inv, duv1) # [N,3]
355
+ else:
356
+ raise TypeError("Input type should be 'torch.tensor' or 'numpy.ndarray'")
357
+ return points3d
358
+
359
+
360
+ def warp_se3(kpts0, params):
361
+ pose01 = params['pose01'] # relative motion
362
+ bbox0 = params['bbox0'] # row, col
363
+ bbox1 = params['bbox1']
364
+ depth0 = params['depth0']
365
+ depth1 = params['depth1']
366
+ intrinsics0 = params['intrinsics0']
367
+ intrinsics1 = params['intrinsics1']
368
+
369
+ # kpts0_valid: valid kpts0
370
+ # z0_valid: depth of valid kpts0
371
+ # ids0: the indices of valid kpts0 ( valid corners and valid depth)
372
+ # ids0_valid_corners: the valid indices of kpts0 in image ( 0<=x<w, 0<=y<h )
373
+ # ids0_valid_depth: the valid indices of kpts0 with valid depth ( depth > 0 )
374
+ z0_valid, kpts0_valid, ids0, ids0_valid_corners, ids0_valid_depth = interpolate_depth(kpts0, depth0)
375
+
376
+ # COLMAP convention
377
+ bkpts0_valid = kpts0_valid + bbox0[[1, 0]][None, :] + 0.5
378
+
379
+ # unproject pixel coordinate to 3D points (camera coordinate system)
380
+ bpoints3d0 = unproject(bkpts0_valid, z0_valid.unsqueeze(1), intrinsics0) # [:,3]
381
+ bpoints3d0_homo = to_homogeneous(bpoints3d0) # [:,4]
382
+
383
+ # warp 3D point (camera 0 coordinate system) to 3D point (camera 1 coordinate system)
384
+ bpoints3d01_homo = torch.einsum('jk,nk->nj', pose01, bpoints3d0_homo) # [:,4]
385
+ bpoints3d01 = bpoints3d01_homo[:, 0:3] # [:,3]
386
+
387
+ # project 3D point (camera coordinate system) to pixel coordinate
388
+ buv01, z01 = project(bpoints3d01, intrinsics1) # uv: [:,2], (h,w); z1: [N]
389
+
390
+ uv01 = buv01 - bbox1[None, [1, 0]] - .5
391
+
392
+ # kpts01_valid: valid kpts01
393
+ # z01_valid: depth of valid kpts01
394
+ # ids01: the indices of valid kpts01 ( valid corners and valid depth)
395
+ # ids01_valid_corners: the valid indices of kpts01 in image ( 0<=x<w, 0<=y<h )
396
+ # ids01_valid_depth: the valid indices of kpts01 with valid depth ( depth > 0 )
397
+ z01_interpolate, kpts01_valid, ids01, ids01_valid_corners, ids01_valid_depth = interpolate_depth(uv01, depth1)
398
+
399
+ outimage_mask = torch.ones(ids0.shape[0], device=ids0.device).bool()
400
+ outimage_mask[ids01_valid_corners] = 0
401
+ ids01_invalid_corners = torch.arange(0, ids0.shape[0], device=ids0.device)[outimage_mask]
402
+ ids_outside = ids0[ids01_invalid_corners]
403
+
404
+ # ids_valid: matched kpts01 without occlusion
405
+ ids_valid = ids0[ids01]
406
+ kpts0_valid = kpts0_valid[ids01]
407
+ z01_proj = z01[ids01]
408
+
409
+ inlier_mask = torch.abs(z01_proj - z01_interpolate) < 0.05
410
+
411
+ # indices of kpts01 with occlusion
412
+ ids_occlude = ids_valid[~inlier_mask]
413
+
414
+ ids_valid = ids_valid[inlier_mask]
415
+ if ids_valid.size(0) == 0:
416
+ # raise ValueError('empty tensor: ids')
417
+ raise EmptyTensorError
418
+
419
+ kpts01_valid = kpts01_valid[inlier_mask]
420
+ kpts0_valid = kpts0_valid[inlier_mask]
421
+
422
+ # indices of kpts01 which certainly have no match in image1;
423
+ # other projected kpts01 are uncertain because depth is missing in image0 or image1
424
+ ids_out = torch.cat([ids_outside, ids_occlude])
425
+
426
+ # kpts0_valid: valid keypoints0, the invalid and inconsistent keypoints are removed
427
+ # kpts01_valid: the warped valid keypoints0
428
+ # ids: the valid indices
429
+ return kpts0_valid, kpts01_valid, ids_valid, ids_out
430
+
431
+
432
+ def warp(kpts0, params: dict):
433
+ mode = params['mode']
434
+ if mode == 'homo':
435
+ return warp_homography(kpts0, params)
436
+ elif mode == 'se3':
437
+ return warp_se3(kpts0, params)
438
+ else:
439
+ raise ValueError('unknown mode!')
440
+
441
+
442
+ def warp_xy(kpts0_xy, params: dict):
443
+ w, h = params['width'], params['height']
444
+ kpts0 = (kpts0_xy / 2 + 0.5) * kpts0_xy.new_tensor([[w - 1, h - 1]])
445
+ kpts0, kpts01, ids, _ = warp(kpts0, params)  # warp returns (kpts, warped kpts, ids, ids_out)
446
+ kpts01_xy = 2 * kpts01 / kpts01.new_tensor([[w - 1, h - 1]]) - 1
447
+ kpts0_xy = 2 * kpts0 / kpts0.new_tensor([[w - 1, h - 1]]) - 1
448
+ return kpts0_xy, kpts01_xy, ids
449
+
450
+
451
+ def scale_intrinsics(K, scales):
452
+ scales = np.diag([1. / scales[0], 1. / scales[1], 1.])
453
+ return np.dot(scales, K)
454
+
455
+
456
+ def warp_points3d(points3d0, pose01):
457
+ points3d0_homo = np.concatenate((points3d0, np.ones(points3d0.shape[0])[:, np.newaxis]), axis=1) # [:,4]
458
+
459
+ points3d01_homo = np.einsum('jk,nk->nj', pose01, points3d0_homo) # [N,4]
460
+ points3d01 = points3d01_homo[:, 0:3] # [N,3]
461
+
462
+ return points3d01
463
+
464
+
465
+ def unproject_depth(depth, K):
466
+ h, w = depth.shape
467
+
468
+ wh_range = np.mgrid[0:w, 0:h].transpose(2, 1, 0) # [H,W,2]
469
+
470
+ uv = wh_range.reshape(-1, 2)
471
+ d = depth.reshape(-1, 1)
472
+ points3d = unproject(uv, d, K)
473
+
474
+ valid = np.logical_and((d[:, 0] > 0), (points3d[:, 2] > 0))
475
+
476
+ return points3d, valid
477
+
478
+
479
+ @jit(nopython=True)
480
+ def project_depth_nn_python(uv, z, depth):
481
+ h, w = depth.shape
482
+ # TODO: speed up the for loop
483
+ for idx in range(len(uv)):
484
+ uvi = uv[idx]
485
+ x = int(round(uvi[0]))
486
+ y = int(round(uvi[1]))
487
+
488
+ if x < 0 or y < 0 or x >= w or y >= h:
489
+ continue
490
+
491
+ if depth[y, x] == 0. or depth[y, x] > z[idx]:
492
+ depth[y, x] = z[idx]
493
+ return depth
494
+
495
+
496
+ def project_nn(uv, z, depth):
497
+ """
498
+ uv: pixel coordinates [N,2]
499
+ z: projected depth (xyz -> z) [N]
500
+ depth: output depth array: [h,w]
501
+ """
502
+ if nn_cython:
503
+ return project_depth_nn_cython(uv.astype(np.float64),
504
+ z.astype(np.float64),
505
+ depth.astype(np.float64))
506
+ else:
507
+ return project_depth_nn_python(uv, z, depth)
508
+
509
+
510
+ def warp_depth(depth0, intrinsics0, intrinsics1, pose01, shape1):
511
+ points3d0, valid0 = unproject_depth(depth0, intrinsics0) # [:,3]
512
+ points3d0 = points3d0[valid0]
513
+
514
+ points3d01 = warp_points3d(points3d0, pose01)
515
+
516
+ uv01, z01 = project(points3d01, intrinsics1) # uv: [N,2], (h,w); z1: [N]
517
+
518
+ depth01 = project_nn(uv01, z01, depth=np.zeros(shape=shape1))
519
+
520
+ return depth01
521
+
522
+
523
+ def warp_points2d(uv0, d0, intrinsics0, intrinsics1, pose01):
524
+ points3d0 = unproject(uv0, d0, intrinsics0)
525
+ points3d01 = warp_points3d(points3d0, pose01)
526
+ uv01, z01 = project(points3d01, intrinsics1)
527
+ return uv01, z01
528
+
529
+
530
+ def display_image_in_actual_size(image):
531
+ import matplotlib.pyplot as plt
532
+
533
+ dpi = 100
534
+ height, width = image.shape[:2]
535
+
536
+ # What size does the figure need to be in inches to fit the image?
537
+ figsize = width / float(dpi), height / float(dpi)
538
+
539
+ # Create a figure of the right size with one axes that takes up the full figure
540
+ fig = plt.figure(figsize=figsize)
541
+ ax = fig.add_axes([0, 0, 1, 1])
542
+
543
+ # Hide spines, ticks, etc.
544
+ ax.axis('off')
545
+
546
+ # Display the image.
547
+ if len(image.shape) == 3:
548
+ ax.imshow(image, cmap='gray')
549
+ elif len(image.shape) == 2:
550
+ if image.dtype == np.uint8:
551
+ ax.imshow(image, cmap='gray')
552
+ else:
553
+ ax.imshow(image)
554
+ ax.text(20, 20, f"Range: {image.min():g}~{image.max():g}", color='red')
555
+ plt.show()
556
+
557
+
558
+ # ====================================== copied from ASLFeat
559
+ from datetime import datetime
560
+
561
+
562
+ class ClassProperty(property):
563
+ """For dynamically obtaining system time"""
564
+
565
+ def __get__(self, cls, owner):
566
+ return classmethod(self.fget).__get__(None, owner)()
567
+
568
+
569
+ class Notify(object):
570
+ """Colorful printing prefix.
571
+ A quick example:
572
+ print(Notify.INFO, YOUR TEXT, Notify.ENDC)
573
+ """
574
+
575
+ def __init__(self):
576
+ pass
577
+
578
+ @ClassProperty
579
+ def HEADER(cls):
580
+ return str(datetime.now()) + ': \033[95m'
581
+
582
+ @ClassProperty
583
+ def INFO(cls):
584
+ return str(datetime.now()) + ': \033[92mI'
585
+
586
+ @ClassProperty
587
+ def OKBLUE(cls):
588
+ return str(datetime.now()) + ': \033[94m'
589
+
590
+ @ClassProperty
591
+ def WARNING(cls):
592
+ return str(datetime.now()) + ': \033[93mW'
593
+
594
+ @ClassProperty
595
+ def FAIL(cls):
596
+ return str(datetime.now()) + ': \033[91mF'
597
+
598
+ @ClassProperty
599
+ def BOLD(cls):
600
+ return str(datetime.now()) + ': \033[1mB'
601
+
602
+ @ClassProperty
603
+ def UNDERLINE(cls):
604
+ return str(datetime.now()) + ': \033[4mU'
605
+
606
+ ENDC = '\033[0m'
607
+
608
+ def get_essential(T0, T1):
609
+ R0 = T0[:3, :3]
610
+ R1 = T1[:3, :3]
611
+
612
+ t0 = T0[:3, 3].reshape(3, 1)
613
+ t1 = T1[:3, 3].reshape(3, 1)
614
+
615
+ R0 = torch.tensor(R0, dtype=torch.float32)
616
+ R1 = torch.tensor(R1, dtype=torch.float32)
617
+ t0 = torch.tensor(t0, dtype=torch.float32)
618
+ t1 = torch.tensor(t1, dtype=torch.float32)
619
+
620
+ E = essential_from_Rt(R0, t0, R1, t1)
621
+
622
+ return E
623
+
624
+ def get_fundamental(E, K0, K1):
625
+ F = fundamental_from_essential(E, K0, K1)
626
+
627
+ return F
628
+ try:
629
+ # for internal use only
630
+ from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
631
+ except Exception:
632
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
633
+
634
+ # --- DATA IO ---
635
+
636
+ def load_array_from_s3(
637
+ path, client, cv_type,
638
+ use_h5py=False,
639
+ ):
640
+ byte_str = client.Get(path)
641
+ try:
642
+ if not use_h5py:
643
+ raw_array = np.fromstring(byte_str, np.uint8)
644
+ data = cv2.imdecode(raw_array, cv_type)
645
+ else:
646
+ f = io.BytesIO(byte_str)
647
+ data = np.array(h5py.File(f, 'r')['/depth'])
648
+ except Exception as ex:
649
+ print(f"==> Data loading failure: {path}")
650
+ raise ex
651
+
652
+ assert data is not None
653
+ return data
654
+
655
+
656
+ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
657
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
658
+ else cv2.IMREAD_COLOR
659
+ if str(path).startswith('s3://'):
660
+ image = load_array_from_s3(str(path), client, cv_type)
661
+ else:
662
+ image = cv2.imread(str(path), 1)
663
+
664
+ if augment_fn is not None:
665
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
666
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
667
+ image = augment_fn(image)
668
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
669
+ return image # (h, w, 3), or (h, w) if augment_fn is not None
670
+
671
+
672
+ def get_resized_wh(w, h, resize=None):
673
+ if resize is not None: # resize the longer edge
674
+ scale = resize / max(h, w)
675
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
676
+ else:
677
+ w_new, h_new = w, h
678
+ return w_new, h_new
679
+
680
+
681
+ def get_divisible_wh(w, h, df=None):
682
+ if df is not None:
683
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
684
+ else:
685
+ w_new, h_new = w, h
686
+ return w_new, h_new
687
+
688
+
689
+ def pad_bottom_right(inp, pad_size, ret_mask=False):
690
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
691
+ mask = None
692
+ if inp.ndim == 2:
693
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
694
+ padded[:inp.shape[0], :inp.shape[1]] = inp
695
+ if ret_mask:
696
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
697
+ mask[:inp.shape[0], :inp.shape[1]] = True
698
+ elif inp.ndim == 3:
699
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
700
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
701
+ if ret_mask:
702
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
703
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
704
+ else:
705
+ raise NotImplementedError()
706
+ return padded, mask
707
+
708
+
709
+ # --- MEGADEPTH ---
710
+
711
+ def fix_path_from_d2net(path):
712
+ if not path:
713
+ return None
714
+
715
+ path = path.replace('Undistorted_SfM/', '')
716
+ path = path.replace('images', 'dense0/imgs')
717
+ path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '')
718
+
719
+ return path
720
+
721
+ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
722
+ """
723
+ Args:
724
+ resize (int, optional): the longer edge of resized images. None for no resize.
725
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
726
+ augment_fn (callable, optional): augments images with pre-defined visual effects
727
+ Returns:
728
+ image (torch.tensor): (3, h, w)
729
+ mask (torch.tensor): (h, w)
730
+ scale (torch.tensor): [w/w_new, h/h_new]
731
+ """
732
+ # read image
733
+ image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
734
+
735
+ # resize image
736
+ w, h = image.shape[1], image.shape[0]
737
+
738
+ if len(resize) == 2:
739
+ w_new, h_new = resize
740
+ else:
741
+ resize = resize[0]
742
+ w_new, h_new = get_resized_wh(w, h, resize)
743
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
744
+
745
+
746
+ image = cv2.resize(image, (w_new, h_new))
747
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
748
+
749
+ if padding: # padding
750
+ pad_to = max(h_new, w_new)
751
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
752
+ else:
753
+ mask = None
754
+
755
+ #image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
756
+ image = torch.from_numpy(image).float().permute(2,0,1) / 255 # (h, w, 3) -> (3, h, w) and normalized
757
+ mask = torch.from_numpy(mask) if mask is not None else None
758
+
759
+ return image, mask, scale
760
+
761
+ def imread_color(path, augment_fn=None, client=SCANNET_CLIENT):
762
+ cv_type = cv2.IMREAD_COLOR
763
+ # if str(path).startswith('s3://'):
764
+ # image = load_array_from_s3(str(path), client, cv_type)
765
+ # else:
766
+ # image = cv2.imread(str(path), cv_type)
767
+
768
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
769
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
770
+ if augment_fn is not None:
771
+ image = augment_fn(image)
772
+ # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
773
+ return image # (h, w, 3)
774
+
775
+
776
+ def read_megadepth_color(path,
777
+ resize=None,
778
+ df=None,
779
+ padding=False,
780
+ augment_fn=None,
781
+ rotation=0):
782
+ """
783
+ Args:
784
+ resize (int, optional): the longer edge of resized images. None for no resize.
785
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
786
+ augment_fn (callable, optional): augments images with pre-defined visual effects
787
+ Returns:
788
+ image (torch.tensor): (3, h, w)
789
+ mask (torch.tensor): (h, w)
790
+ scale (torch.tensor): [w/w_new, h/h_new]
791
+ """
792
+ # read image
793
+ image = imread_color(path, augment_fn, client=MEGADEPTH_CLIENT)
794
+
795
+ if rotation != 0:
796
+ image = np.rot90(image, k=rotation).copy()
797
+
798
+ # resize image
799
+ if resize is not None:
800
+ w, h = image.shape[1], image.shape[0]
801
+ if len(resize) == 2:
802
+ w_new, h_new = resize
803
+ else:
804
+ resize = resize[0]
805
+ w_new, h_new = get_resized_wh(w, h, resize)
806
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
807
+
808
+
809
+ image = cv2.resize(image, (w_new, h_new))
810
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
811
+ scale_wh = torch.tensor([w_new, h_new], dtype=torch.float)
812
+ else:
813
+ scale = torch.tensor([1., 1.], dtype=torch.float)
814
+ scale_wh = torch.tensor([image.shape[1], image.shape[0]], dtype=torch.float)
815
+
816
+ image = image.transpose(2, 0, 1)
817
+
818
+ if padding: # padding
819
+ if resize is not None:
820
+ pad_to = max(h_new, w_new)
821
+ else:
822
+ pad_to = 2000
823
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
824
+ else:
825
+ mask = None
826
+
827
+ image = torch.from_numpy(image).float() / 255 # (3, h, w), normalized to [0, 1]
828
+ mask = torch.from_numpy(mask) if mask is not None else None
829
+
830
+ return image, mask, scale
831
+
832
+ def read_megadepth_depth(path, pad_to=None):
833
+
834
+ if str(path).startswith('s3://'):
835
+ depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
836
+ else:
837
+ depth = np.array(h5py.File(path, 'r')['depth'])
838
+ if pad_to is not None:
839
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
840
+ depth = torch.from_numpy(depth).float() # (h, w)
841
+ return depth
842
+
843
+ def get_image_name(path):
844
+ return path.split('/')[-1]
845
+
846
+ def scale_intrinsics(K, scales):
847
+ scales = np.diag([1. / scales[0], 1. / scales[1], 1.])
848
+ return np.dot(scales, K)
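
As a quick sanity check of the project/unproject pair defined above, the following round trip should reproduce the input pixels. The intrinsics and pixel values are made up for illustration, and the import path is an assumption.

import torch
from RDD.dataset.megadepth.utils import project, unproject  # import path assumed

K = torch.tensor([[500., 0., 320.],
                  [0., 500., 240.],
                  [0., 0., 1.]])
uv = torch.tensor([[100., 50.], [320., 240.]])   # pixel coordinates, [N, 2]
d = torch.tensor([[2.0], [4.0]])                 # depths, [N, 1]

pts3d = unproject(uv, d, K)      # camera-frame 3D points, [N, 3]
uv_back, z = project(pts3d, K)   # reprojected pixels and their depths
assert torch.allclose(uv_back, uv, atol=1e-4) and torch.allclose(z, d.squeeze(1))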
imcui/third_party/rdd/RDD/matchers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .dual_softmax_matcher import DualSoftmaxMatcher
2
+ from .dense_matcher import DenseMatcher
3
+ from .lightglue import LightGlue
imcui/third_party/rdd/RDD/matchers/dense_matcher.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import poselib
5
+
6
+ class DenseMatcher(nn.Module):
7
+ def __init__(self, inv_temperature = 20, thr = 0.01):
8
+ super().__init__()
9
+ self.inv_temperature = inv_temperature
10
+ self.thr = thr
11
+
12
+ def forward(self, info0, info1, thr = None, err_thr=4, min_num_inliers=30):
13
+
14
+ desc0 = info0['descriptors']
15
+ desc1 = info1['descriptors']
16
+
17
+ inds, P = self.dual_softmax(desc0, desc1, thr=thr)
18
+
19
+ mkpts_0 = info0['keypoints'][inds[:,0]]
20
+ mkpts_1 = info1['keypoints'][inds[:,1]]
21
+ mconf = P[inds[:,0], inds[:,1]]
22
+ Fm, inliers = self.get_fundamental_matrix(mkpts_0, mkpts_1)
23
+
24
+ if inliers.sum() >= min_num_inliers:
25
+ desc1_dense = info0['descriptors_dense']
26
+ desc2_dense = info1['descriptors_dense']
27
+
28
+ inds_dense, P_dense = self.dual_softmax(desc1_dense, desc2_dense, thr=thr)
29
+
30
+ mkpts_0_dense = info0['keypoints_dense'][inds_dense[:,0]]
31
+ mkpts_1_dense = info1['keypoints_dense'][inds_dense[:,1]]
32
+ mconf_dense = P_dense[inds_dense[:,0], inds_dense[:,1]]
33
+
34
+ mkpts_0_dense, mkpts_1_dense, mconf_dense = self.refine_matches(mkpts_0_dense, mkpts_1_dense, mconf_dense, Fm, err_thr=err_thr)
35
+ mkpts_0 = mkpts_0[inliers]
36
+ mkpts_1 = mkpts_1[inliers]
37
+ mconf = mconf[inliers]
38
+ # concatenate the matches
39
+ mkpts_0 = torch.cat([mkpts_0, mkpts_0_dense], dim=0)
40
+ mkpts_1 = torch.cat([mkpts_1, mkpts_1_dense], dim=0)
41
+ mconf = torch.cat([mconf, mconf_dense], dim=0)
42
+
43
+ return mkpts_0, mkpts_1, mconf
44
+
45
+ def get_fundamental_matrix(self, kpts_0, kpts_1):
46
+ Fm, info = poselib.estimate_fundamental(kpts_0.cpu().numpy(), kpts_1.cpu().numpy(), {'max_epipolar_error': 1, 'progressive_sampling': True}, {})
47
+ inliers = info['inliers']
48
+ Fm = torch.tensor(Fm, device=kpts_0.device, dtype=kpts_0.dtype)
49
+ inliers = torch.tensor(inliers, device=kpts_0.device, dtype=torch.bool)
50
+ return Fm, inliers
51
+
52
+ def dual_softmax(self, desc0, desc1, thr = None):
53
+ if thr is None:
54
+ thr = self.thr
55
+ dist_mat = (desc0 @ desc1.t()) * self.inv_temperature
56
+ P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1)
57
+ inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values)
58
+ * (P == P.max(dim=-2, keepdim = True).values) * (P >= thr))
59
+
60
+ return inds, P
61
+
62
+ @torch.inference_mode()
63
+ def refine_matches(self, mkpts_0, mkpts_1, mconf, Fm, err_thr=4):
64
+ mkpts_0_h = torch.cat([mkpts_0, torch.ones(mkpts_0.shape[0], 1, device=mkpts_0.device)], dim=1) # (N, 3)
65
+ mkpts_1_h = torch.cat([mkpts_1, torch.ones(mkpts_1.shape[0], 1, device=mkpts_1.device)], dim=1) # (N, 3)
66
+
67
+ lines_1 = torch.matmul(Fm, mkpts_0_h.T).T
68
+
69
+ a, b, c = lines_1[:, 0], lines_1[:, 1], lines_1[:, 2]
70
+
71
+ x1, y1 = mkpts_1[:, 0], mkpts_1[:, 1]
72
+
73
+ denom = a**2 + b**2 + 1e-8
74
+
75
+ x_offset = (b * (b * x1 - a * y1) - a * c) / denom - x1
76
+ y_offset = (a * (a * y1 - b * x1) - b * c) / denom - y1
77
+
78
+ inds = (x_offset.abs() < err_thr) | (y_offset.abs() < err_thr)
79
+
80
+ x_offset = x_offset[inds]
81
+ y_offset = y_offset[inds]
82
+
83
+ mkpts_0 = mkpts_0[inds]
84
+ mkpts_1 = mkpts_1[inds]
85
+
86
+ refined_mkpts_1 = mkpts_1 + torch.stack([x_offset, y_offset], dim=1)
87
+
88
+ return mkpts_0, refined_mkpts_1, mconf[inds]
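
The epipolar refinement in refine_matches above can be illustrated with a hand-built fundamental matrix. This is a sketch under assumed values, not RDD's actual inference path: the import path is assumed, poselib must be installed because the module imports it, and the matrix below corresponds to a pure horizontal translation, where epipolar lines satisfy y' = y.

import torch
from RDD.matchers import DenseMatcher  # import path assumed; module requires poselib

matcher = DenseMatcher()

# Pure horizontal translation: the epipolar line of (x, y) in image 1 is y' = y.
Fm = torch.tensor([[0., 0., 0.],
                   [0., 0., -1.],
                   [0., 1., 0.]])
mkpts_0 = torch.tensor([[10., 20.], [30., 40.]])
mkpts_1 = torch.tensor([[12., 22.5], [33., 41.0]])  # slightly off their epipolar lines
mconf = torch.ones(2)

p0, p1, c = matcher.refine_matches(mkpts_0, mkpts_1, mconf, Fm, err_thr=4)
# p1 becomes [[12., 20.], [33., 40.]]: the y coordinates are snapped back onto y' = y.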
imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class DualSoftmaxMatcher(nn.Module):
6
+ def __init__(self, inv_temperature = 20, thr = 0.01):
7
+ super().__init__()
8
+ self.inv_temperature = inv_temperature
9
+ self.thr = thr
10
+
11
+ def forward(self, info0, info1, thr = None):
12
+ desc0 = info0['descriptors']
13
+ desc1 = info1['descriptors']
14
+
15
+ inds, P = self.dual_softmax(desc0, desc1, thr)
16
+ mkpts0 = info0['keypoints'][inds[:,0]]
17
+ mkpts1 = info1['keypoints'][inds[:,1]]
18
+ mconf = P[inds[:,0], inds[:,1]]
19
+
20
+ return mkpts0, mkpts1, mconf
21
+
22
+ def dual_softmax(self, desc0, desc1, thr = None):
23
+ if thr is None:
24
+ thr = self.thr
25
+ dist_mat = (desc0 @ desc1.t()) * self.inv_temperature
26
+ P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1)
27
+
28
+ inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values)
29
+ * (P == P.max(dim=-2, keepdim = True).values) * (P >= thr))
30
+
31
+ return inds, P
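
For reference, the mutual dual-softmax rule used above can be exercised on synthetic descriptors. The permutation test below is illustrative only; the import path is an assumption, and importing RDD.matchers pulls in the other matchers and their dependencies (e.g. poselib).

import torch
import torch.nn.functional as F
from RDD.matchers import DualSoftmaxMatcher  # import path assumed

matcher = DualSoftmaxMatcher(inv_temperature=20, thr=0.01)

desc0 = F.normalize(torch.randn(4, 256), dim=-1)
perm = torch.tensor([2, 0, 3, 1])
desc1 = desc0[perm]                  # the same descriptors, permuted

inds, P = matcher.dual_softmax(desc0, desc1)
# Each recovered pair (i, j) links row i of desc0 to the row j of desc1 holding the same vector,
# i.e. perm[inds[:, 1]] == inds[:, 0] for every pair.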
imcui/third_party/rdd/RDD/matchers/lightglue.py ADDED
@@ -0,0 +1,667 @@
1
+ """
2
+ Modified from
3
+ https://github.com/cvg/LightGlue
4
+ """
5
+
6
+ import warnings
7
+ from pathlib import Path
8
+ from types import SimpleNamespace
9
+ from typing import Callable, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+ try:
17
+ from flash_attn.modules.mha import FlashCrossAttention
18
+ except ModuleNotFoundError:
19
+ FlashCrossAttention = None
20
+
21
+ if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
22
+ FLASH_AVAILABLE = True
23
+ else:
24
+ FLASH_AVAILABLE = False
25
+
26
+ torch.backends.cudnn.deterministic = True
27
+
28
+
29
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
30
+ def normalize_keypoints(
31
+ kpts: torch.Tensor, size: Optional[torch.Tensor] = None
32
+ ) -> torch.Tensor:
33
+ if size is None:
34
+ size = 1 + kpts.max(-2).values - kpts.min(-2).values
35
+ elif not isinstance(size, torch.Tensor):
36
+ size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
37
+ size = size.to(kpts)
38
+ shift = size / 2
39
+ scale = size.max(-1).values / 2
40
+ kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
41
+ return kpts
42
+
43
+
44
+ def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
45
+ if length <= x.shape[-2]:
46
+ return x, torch.ones_like(x[..., :1], dtype=torch.bool)
47
+ pad = torch.ones(
48
+ *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
49
+ )
50
+ y = torch.cat([x, pad], dim=-2)
51
+ mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
52
+ mask[..., : x.shape[-2], :] = True
53
+ return y, mask
54
+
55
+
56
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
57
+ x = x.unflatten(-1, (-1, 2))
58
+ x1, x2 = x.unbind(dim=-1)
59
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
60
+
61
+
62
+ def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
63
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
64
+
65
+
66
+ class LearnableFourierPositionalEncoding(nn.Module):
67
+ def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
68
+ super().__init__()
69
+ F_dim = F_dim if F_dim is not None else dim
70
+ self.gamma = gamma
71
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
72
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ """encode position vector"""
76
+ projected = self.Wr(x)
77
+ cosines, sines = torch.cos(projected), torch.sin(projected)
78
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
79
+ return emb.repeat_interleave(2, dim=-1)
80
+
81
+
82
+ class TokenConfidence(nn.Module):
83
+ def __init__(self, dim: int) -> None:
84
+ super().__init__()
85
+ self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
86
+
87
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
88
+ """get confidence tokens"""
89
+ return (
90
+ self.token(desc0.detach()).squeeze(-1),
91
+ self.token(desc1.detach()).squeeze(-1),
92
+ )
93
+
94
+
95
+ class Attention(nn.Module):
96
+ def __init__(self, allow_flash: bool) -> None:
97
+ super().__init__()
98
+ if allow_flash and not FLASH_AVAILABLE:
99
+ warnings.warn(
100
+ "FlashAttention is not available. For optimal speed, "
101
+ "consider installing torch >= 2.0 or flash-attn.",
102
+ stacklevel=2,
103
+ )
104
+ self.enable_flash = allow_flash and FLASH_AVAILABLE
105
+ self.has_sdp = hasattr(F, "scaled_dot_product_attention")
106
+ if allow_flash and FlashCrossAttention:
107
+ self.flash_ = FlashCrossAttention()
108
+ if self.has_sdp:
109
+ torch.backends.cuda.enable_flash_sdp(allow_flash)
110
+
111
+ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
112
+ if q.shape[-2] == 0 or k.shape[-2] == 0:
113
+ return q.new_zeros((*q.shape[:-1], v.shape[-1]))
114
+ if self.enable_flash and q.device.type == "cuda":
115
+ # use torch 2.0 scaled_dot_product_attention with flash
116
+ if self.has_sdp:
117
+ args = [x.half().contiguous() for x in [q, k, v]]
118
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
119
+ return v if mask is None else v.nan_to_num()
120
+ else:
121
+ assert mask is None
122
+ q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
123
+ m = self.flash_(q.half(), torch.stack([k, v], 2).half())
124
+ return m.transpose(-2, -3).to(q.dtype).clone()
125
+ elif self.has_sdp:
126
+ args = [x.contiguous() for x in [q, k, v]]
127
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask)
128
+ return v if mask is None else v.nan_to_num()
129
+ else:
130
+ s = q.shape[-1] ** -0.5
131
+ sim = torch.einsum("...id,...jd->...ij", q, k) * s
132
+ if mask is not None:
133
+ sim.masked_fill(~mask, -float("inf"))
134
+ attn = F.softmax(sim, -1)
135
+ return torch.einsum("...ij,...jd->...id", attn, v)
136
+
137
+
138
+ class SelfBlock(nn.Module):
139
+ def __init__(
140
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
141
+ ) -> None:
142
+ super().__init__()
143
+ self.embed_dim = embed_dim
144
+ self.num_heads = num_heads
145
+ assert self.embed_dim % num_heads == 0
146
+ self.head_dim = self.embed_dim // num_heads
147
+ self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
148
+ self.inner_attn = Attention(flash)
149
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
150
+ self.ffn = nn.Sequential(
151
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
152
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
153
+ nn.GELU(),
154
+ nn.Linear(2 * embed_dim, embed_dim),
155
+ )
156
+
157
+ def forward(
158
+ self,
159
+ x: torch.Tensor,
160
+ encoding: torch.Tensor,
161
+ mask: Optional[torch.Tensor] = None,
162
+ ) -> torch.Tensor:
163
+ qkv = self.Wqkv(x)
164
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
165
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
166
+ q = apply_cached_rotary_emb(encoding, q)
167
+ k = apply_cached_rotary_emb(encoding, k)
168
+ context = self.inner_attn(q, k, v, mask=mask)
169
+ message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
170
+ return x + self.ffn(torch.cat([x, message], -1))
171
+
172
+
173
+ class CrossBlock(nn.Module):
174
+ def __init__(
175
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
176
+ ) -> None:
177
+ super().__init__()
178
+ self.heads = num_heads
179
+ dim_head = embed_dim // num_heads
180
+ self.scale = dim_head**-0.5
181
+ inner_dim = dim_head * num_heads
182
+ self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
183
+ self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
184
+ self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
185
+ self.ffn = nn.Sequential(
186
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
187
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
188
+ nn.GELU(),
189
+ nn.Linear(2 * embed_dim, embed_dim),
190
+ )
191
+ if flash and FLASH_AVAILABLE:
192
+ self.flash = Attention(True)
193
+ else:
194
+ self.flash = None
195
+
196
+ def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
197
+ return func(x0), func(x1)
198
+
199
+ def forward(
200
+ self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
201
+ ) -> List[torch.Tensor]:
202
+ qk0, qk1 = self.map_(self.to_qk, x0, x1)
203
+ v0, v1 = self.map_(self.to_v, x0, x1)
204
+ qk0, qk1, v0, v1 = map(
205
+ lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
206
+ (qk0, qk1, v0, v1),
207
+ )
208
+ if self.flash is not None and qk0.device.type == "cuda":
209
+ m0 = self.flash(qk0, qk1, v1, mask)
210
+ m1 = self.flash(
211
+ qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
212
+ )
213
+ else:
214
+ qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
215
+ sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
216
+ if mask is not None:
217
+ sim = sim.masked_fill(~mask, -float("inf"))
218
+ attn01 = F.softmax(sim, dim=-1)
219
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
220
+ m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
221
+ m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
222
+ if mask is not None:
223
+ m0, m1 = m0.nan_to_num(), m1.nan_to_num()
224
+ m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
225
+ m0, m1 = self.map_(self.to_out, m0, m1)
226
+ x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
227
+ x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
228
+ return x0, x1
229
+
230
+
231
+ class TransformerLayer(nn.Module):
232
+ def __init__(self, *args, **kwargs):
233
+ super().__init__()
234
+ self.self_attn = SelfBlock(*args, **kwargs)
235
+ self.cross_attn = CrossBlock(*args, **kwargs)
236
+
237
+ def forward(
238
+ self,
239
+ desc0,
240
+ desc1,
241
+ encoding0,
242
+ encoding1,
243
+ mask0: Optional[torch.Tensor] = None,
244
+ mask1: Optional[torch.Tensor] = None,
245
+ ):
246
+ if mask0 is not None and mask1 is not None:
247
+ return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
248
+ else:
249
+ desc0 = self.self_attn(desc0, encoding0)
250
+ desc1 = self.self_attn(desc1, encoding1)
251
+ return self.cross_attn(desc0, desc1)
252
+
253
+ # This part is compiled and allows padding inputs
254
+ def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
255
+ mask = mask0 & mask1.transpose(-1, -2)
256
+ mask0 = mask0 & mask0.transpose(-1, -2)
257
+ mask1 = mask1 & mask1.transpose(-1, -2)
258
+ desc0 = self.self_attn(desc0, encoding0, mask0)
259
+ desc1 = self.self_attn(desc1, encoding1, mask1)
260
+ return self.cross_attn(desc0, desc1, mask)
261
+
262
+
263
+ def sigmoid_log_double_softmax(
264
+ sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
265
+ ) -> torch.Tensor:
266
+ """create the log assignment matrix from logits and similarity"""
267
+ b, m, n = sim.shape
268
+ certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
269
+ scores0 = F.log_softmax(sim, 2)
270
+ scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
271
+ scores = sim.new_full((b, m + 1, n + 1), 0)
272
+ scores[:, :m, :n] = scores0 + scores1 + certainties
273
+ scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
274
+ scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
275
+ return scores
276
+
277
+
278
+ class MatchAssignment(nn.Module):
279
+ def __init__(self, dim: int) -> None:
280
+ super().__init__()
281
+ self.dim = dim
282
+ self.matchability = nn.Linear(dim, 1, bias=True)
283
+ self.final_proj = nn.Linear(dim, dim, bias=True)
284
+
285
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
286
+ """build assignment matrix from descriptors"""
287
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
288
+ _, _, d = mdesc0.shape
289
+ mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
290
+ sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
291
+ z0 = self.matchability(desc0)
292
+ z1 = self.matchability(desc1)
293
+ scores = sigmoid_log_double_softmax(sim, z0, z1)
294
+ return scores, sim
295
+
296
+ def get_matchability(self, desc: torch.Tensor):
297
+ return torch.sigmoid(self.matchability(desc)).squeeze(-1)
298
+
299
+
300
+ def filter_matches(scores: torch.Tensor, th: float):
301
+ """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
302
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
303
+ m0, m1 = max0.indices, max1.indices
304
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
305
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
306
+ mutual0 = indices0 == m1.gather(1, m0)
307
+ mutual1 = indices1 == m0.gather(1, m1)
308
+ max0_exp = max0.values.exp()
309
+ zero = max0_exp.new_tensor(0)
310
+ mscores0 = torch.where(mutual0, max0_exp, zero)
311
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
312
+ valid0 = mutual0 & (mscores0 > th)
313
+ valid1 = mutual1 & valid0.gather(1, m1)
314
+ m0 = torch.where(valid0, m0, -1)
315
+ m1 = torch.where(valid1, m1, -1)
316
+ return m0, m1, mscores0, mscores1
317
+
318
+
319
+ class LightGlue(nn.Module):
320
+ default_conf = {
321
+ "name": "lightglue", # just for interfacing
322
+ "input_dim": 256, # input descriptor dimension (autoselected from weights)
323
+ "descriptor_dim": 256,
324
+ "add_scale_ori": False,
325
+ "n_layers": 9,
326
+ "num_heads": 4,
327
+ "flash": True, # enable FlashAttention if available.
328
+ "mp": False, # enable mixed precision
329
+ "depth_confidence": -1, # early stopping, disable with -1
330
+ "width_confidence": -1, # point pruning, disable with -1
331
+ "filter_threshold": 0.01, # match threshold
332
+ "weights": None,
333
+ }
334
+
335
+ # Point pruning involves an overhead (gather).
336
+ # Therefore, we only activate it if there are enough keypoints.
337
+ pruning_keypoint_thresholds = {
338
+ "cpu": -1,
339
+ "mps": -1,
340
+ "cuda": 1024,
341
+ "flash": 1536,
342
+ }
343
+
344
+ required_data_keys = ["image0", "image1"]
345
+
346
+ version = "v0.1_arxiv"
347
+ url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
348
+
349
+ features = {
350
+ "superpoint": {
351
+ "weights": "superpoint_lightglue",
352
+ "input_dim": 256,
353
+ },
354
+ "disk": {
355
+ "weights": "disk_lightglue",
356
+ "input_dim": 128,
357
+ },
358
+ "aliked": {
359
+ "weights": "aliked_lightglue",
360
+ "input_dim": 128,
361
+ },
362
+ "sift": {
363
+ "weights": "sift_lightglue",
364
+ "input_dim": 128,
365
+ "add_scale_ori": True,
366
+ },
367
+ "doghardnet": {
368
+ "weights": "doghardnet_lightglue",
369
+ "input_dim": 128,
370
+ "add_scale_ori": True,
371
+ },
372
+ "rdd": {
373
+ "weights": './weights/RDD_lg-v2.pth',
374
+ "input_dim": 256,
375
+ },
376
+ }
377
+
378
+ def __init__(self, features="rdd", **conf) -> None:
379
+ super().__init__()
380
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
381
+ if features is not None:
382
+ if features not in self.features:
383
+ raise ValueError(
384
+ f"Unsupported features: {features} not in "
385
+ f"{{{','.join(self.features)}}}"
386
+ )
387
+ for k, v in self.features[features].items():
388
+ setattr(conf, k, v)
389
+
390
+ if conf.input_dim != conf.descriptor_dim:
391
+ self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
392
+ else:
393
+ self.input_proj = nn.Identity()
394
+
395
+ head_dim = conf.descriptor_dim // conf.num_heads
396
+ self.posenc = LearnableFourierPositionalEncoding(
397
+ 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
398
+ )
399
+
400
+ h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
401
+
402
+ self.transformers = nn.ModuleList(
403
+ [TransformerLayer(d, h, conf.flash) for _ in range(n)]
404
+ )
405
+
406
+ self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
407
+ self.token_confidence = nn.ModuleList(
408
+ [TokenConfidence(d) for _ in range(n - 1)]
409
+ )
410
+ self.register_buffer(
411
+ "confidence_thresholds",
412
+ torch.Tensor(
413
+ [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
414
+ ),
415
+ )
416
+
417
+ state_dict = None
418
+ if features is not None and features != 'rdd':
419
+ fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
420
+ state_dict = torch.hub.load_state_dict_from_url(
421
+ self.url.format(self.version, features), file_name=fname
422
+ )
423
+ self.load_state_dict(state_dict, strict=False)
424
+ elif conf.weights is not None:
425
+ if features == 'rdd':
426
+ path = Path(conf.weights)
427
+ else:
428
+ path = Path(__file__).parent
429
+ path = path / "weights/{}.pth".format(self.conf.weights)
430
+ state_dict = torch.load(str(path), map_location="cpu")
431
+
432
+ if state_dict:
433
+ # rename old state dict entries
434
+ for i in range(self.conf.n_layers):
435
+ pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
436
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
437
+ pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
438
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
439
+ self.load_state_dict(state_dict, strict=False)
440
+
441
+ # static lengths LightGlue is compiled for (only used with torch.compile)
442
+ self.static_lengths = None
443
+
444
+ def compile(
445
+ self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
446
+ ):
447
+ if self.conf.width_confidence != -1:
448
+ warnings.warn(
449
+ "Point pruning is partially disabled for compiled forward.",
450
+ stacklevel=2,
451
+ )
452
+
453
+ torch._inductor.cudagraph_mark_step_begin()
454
+ for i in range(self.conf.n_layers):
455
+ self.transformers[i].masked_forward = torch.compile(
456
+ self.transformers[i].masked_forward, mode=mode, fullgraph=True
457
+ )
458
+
459
+ self.static_lengths = static_lengths
460
+
461
+ def forward(self, data: dict) -> dict:
462
+ """
463
+ Match keypoints and descriptors between two images
464
+
465
+ Input (dict):
466
+ image0: dict
467
+ keypoints: [B x M x 2]
468
+ descriptors: [B x M x D]
469
+ image: [B x C x H x W] or image_size: [B x 2]
470
+ image1: dict
471
+ keypoints: [B x N x 2]
472
+ descriptors: [B x N x D]
473
+ image: [B x C x H x W] or image_size: [B x 2]
474
+ Output (dict):
475
+ matches0: [B x M]
476
+ matching_scores0: [B x M]
477
+ matches1: [B x N]
478
+ matching_scores1: [B x N]
479
+ matches: List[[Si x 2]]
480
+ scores: List[[Si]]
481
+ stop: int
482
+ prune0: [B x M]
483
+ prune1: [B x N]
484
+ """
485
+ with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
486
+ return self._forward(data)
487
+
488
+ def _forward(self, data: dict) -> dict:
489
+ for key in self.required_data_keys:
490
+ assert key in data, f"Missing key {key} in data"
491
+ data0, data1 = data["image0"], data["image1"]
492
+ kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
493
+ b, m, _ = kpts0.shape
494
+ b, n, _ = kpts1.shape
495
+ device = kpts0.device
496
+ size0, size1 = data0.get("image_size"), data1.get("image_size")
497
+ kpts0 = normalize_keypoints(kpts0, size0).clone()
498
+ kpts1 = normalize_keypoints(kpts1, size1).clone()
499
+
500
+ if self.conf.add_scale_ori:
501
+ kpts0 = torch.cat(
502
+ [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
503
+ )
504
+ kpts1 = torch.cat(
505
+ [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
506
+ )
507
+ desc0 = data0["descriptors"].detach().contiguous()
508
+ desc1 = data1["descriptors"].detach().contiguous()
509
+
510
+ assert desc0.shape[-1] == self.conf.input_dim
511
+ assert desc1.shape[-1] == self.conf.input_dim
512
+
513
+ if torch.is_autocast_enabled():
514
+ desc0 = desc0.half()
515
+ desc1 = desc1.half()
516
+
517
+ mask0, mask1 = None, None
518
+ c = max(m, n)
519
+ do_compile = self.static_lengths and c <= max(self.static_lengths)
520
+ if do_compile:
521
+ kn = min([k for k in self.static_lengths if k >= c])
522
+ desc0, mask0 = pad_to_length(desc0, kn)
523
+ desc1, mask1 = pad_to_length(desc1, kn)
524
+ kpts0, _ = pad_to_length(kpts0, kn)
525
+ kpts1, _ = pad_to_length(kpts1, kn)
526
+ desc0 = self.input_proj(desc0)
527
+ desc1 = self.input_proj(desc1)
528
+ # cache positional embeddings
529
+ encoding0 = self.posenc(kpts0)
530
+ encoding1 = self.posenc(kpts1)
531
+
532
+ # GNN + final_proj + assignment
533
+ do_early_stop = self.conf.depth_confidence > 0
534
+ do_point_pruning = self.conf.width_confidence > 0 and not do_compile
535
+ pruning_th = self.pruning_min_kpts(device)
536
+ if do_point_pruning:
537
+ ind0 = torch.arange(0, m, device=device)[None]
538
+ ind1 = torch.arange(0, n, device=device)[None]
539
+ # We store the index of the layer at which pruning is detected.
540
+ prune0 = torch.ones_like(ind0)
541
+ prune1 = torch.ones_like(ind1)
542
+ token0, token1 = None, None
543
+ for i in range(self.conf.n_layers):
544
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
545
+ break
546
+ desc0, desc1 = self.transformers[i](
547
+ desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
548
+ )
549
+ if i == self.conf.n_layers - 1:
550
+ continue # no early stopping or adaptive width at last layer
551
+
552
+ if do_early_stop:
553
+ token0, token1 = self.token_confidence[i](desc0, desc1)
554
+ if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
555
+ break
556
+ if do_point_pruning and desc0.shape[-2] > pruning_th:
557
+ scores0 = self.log_assignment[i].get_matchability(desc0)
558
+ prunemask0 = self.get_pruning_mask(token0, scores0, i)
559
+ keep0 = torch.where(prunemask0)[1]
560
+ ind0 = ind0.index_select(1, keep0)
561
+ desc0 = desc0.index_select(1, keep0)
562
+ encoding0 = encoding0.index_select(-2, keep0)
563
+ prune0[:, ind0] += 1
564
+ if do_point_pruning and desc1.shape[-2] > pruning_th:
565
+ scores1 = self.log_assignment[i].get_matchability(desc1)
566
+ prunemask1 = self.get_pruning_mask(token1, scores1, i)
567
+ keep1 = torch.where(prunemask1)[1]
568
+ ind1 = ind1.index_select(1, keep1)
569
+ desc1 = desc1.index_select(1, keep1)
570
+ encoding1 = encoding1.index_select(-2, keep1)
571
+ prune1[:, ind1] += 1
572
+
573
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
574
+ m0 = desc0.new_full((b, m), -1, dtype=torch.long)
575
+ m1 = desc1.new_full((b, n), -1, dtype=torch.long)
576
+ mscores0 = desc0.new_zeros((b, m))
577
+ mscores1 = desc1.new_zeros((b, n))
578
+ matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
579
+ mscores = desc0.new_empty((b, 0))
580
+ if not do_point_pruning:
581
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
582
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
583
+ return {
584
+ "matches0": m0,
585
+ "matches1": m1,
586
+ "matching_scores0": mscores0,
587
+ "matching_scores1": mscores1,
588
+ "stop": i + 1,
589
+ "matches": matches,
590
+ "scores": mscores,
591
+ "prune0": prune0,
592
+ "prune1": prune1,
593
+ }
594
+
595
+ desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
596
+ scores, _ = self.log_assignment[i](desc0, desc1)
597
+ m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
598
+ matches, mscores = [], []
599
+ for k in range(b):
600
+ valid = m0[k] > -1
601
+ m_indices_0 = torch.where(valid)[0]
602
+ m_indices_1 = m0[k][valid]
603
+ if do_point_pruning:
604
+ m_indices_0 = ind0[k, m_indices_0]
605
+ m_indices_1 = ind1[k, m_indices_1]
606
+ matches.append(torch.stack([m_indices_0, m_indices_1], -1))
607
+ mscores.append(mscores0[k][valid])
608
+
609
+ # TODO: Remove when hloc switches to the compact format.
610
+ if do_point_pruning:
611
+ m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
612
+ m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
613
+ m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
614
+ m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
615
+ mscores0_ = torch.zeros((b, m), device=mscores0.device)
616
+ mscores1_ = torch.zeros((b, n), device=mscores1.device)
617
+ mscores0_[:, ind0] = mscores0
618
+ mscores1_[:, ind1] = mscores1
619
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
620
+ else:
621
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
622
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
623
+
624
+ return {
625
+ "matches0": m0,
626
+ "matches1": m1,
627
+ "matching_scores0": mscores0,
628
+ "matching_scores1": mscores1,
629
+ "stop": i + 1,
630
+ "matches": matches,
631
+ "scores": mscores,
632
+ "prune0": prune0,
633
+ "prune1": prune1,
634
+ }
635
+
636
+ def confidence_threshold(self, layer_index: int) -> float:
637
+ """scaled confidence threshold"""
638
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
639
+ return np.clip(threshold, 0, 1)
640
+
641
+ def get_pruning_mask(
642
+ self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
643
+ ) -> torch.Tensor:
644
+ """mask points which should be removed"""
645
+ keep = scores > (1 - self.conf.width_confidence)
646
+ if confidences is not None: # Low-confidence points are never pruned.
647
+ keep |= confidences <= self.confidence_thresholds[layer_index]
648
+ return keep
649
+
650
+ def check_if_stop(
651
+ self,
652
+ confidences0: torch.Tensor,
653
+ confidences1: torch.Tensor,
654
+ layer_index: int,
655
+ num_points: int,
656
+ ) -> torch.Tensor:
657
+ """evaluate stopping condition"""
658
+ confidences = torch.cat([confidences0, confidences1], -1)
659
+ threshold = self.confidence_thresholds[layer_index]
660
+ ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
661
+ return ratio_confident > self.conf.depth_confidence
662
+
663
+ def pruning_min_kpts(self, device: torch.device):
664
+ if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
665
+ return self.pruning_keypoint_thresholds["flash"]
666
+ else:
667
+ return self.pruning_keypoint_thresholds[device.type]
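The dictionary returned by `_forward` above is what the hloc pipeline consumes downstream. A minimal consumption sketch, assuming an already-built matcher instance `matcher` and hloc-style feature dicts `feats0`/`feats1` (these names are illustrative, not part of this commit):

    pred = matcher({"image0": feats0, "image1": feats1})
    m0 = pred["matches0"][0]                    # (M,) index into keypoints1, -1 = unmatched
    valid = m0 > -1
    mkpts0 = feats0["keypoints"][0][valid]      # matched keypoints in image 0
    mkpts1 = feats1["keypoints"][0][m0[valid]]  # their correspondences in image 1
    conf = pred["matching_scores0"][0][valid]   # per-match confidence from the assignment head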
imcui/third_party/rdd/RDD/models/backbone.py ADDED
@@ -0,0 +1,147 @@
1
+ # Modified from Deformable DETR
2
+ # https://github.com/fundamentalvision/Deformable-DETR
3
+ # ------------------------------------------------------------------------
4
+ # Deformable DETR
5
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
6
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
7
+ # ------------------------------------------------------------------------
8
+ # Modified from DETR (https://github.com/facebookresearch/detr)
9
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
10
+ # ------------------------------------------------------------------------
11
+
12
+ """
13
+ Backbone modules.
14
+ """
15
+ from collections import OrderedDict
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torchvision
20
+ from torch import nn
21
+ from torchvision.models._utils import IntermediateLayerGetter
22
+ from typing import Dict, List
23
+ import torch.distributed as dist
24
+ from .position_encoding import build_position_encoding
25
+
26
+ from ..utils.misc import NestedTensor, is_main_process
27
+
28
+ class FrozenBatchNorm2d(torch.nn.Module):
29
+ """
30
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
31
+
32
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
33
+ without which any other models than torchvision.models.resnet[18,34,50,101]
34
+ produce nans.
35
+ """
36
+
37
+ def __init__(self, n, eps=1e-5):
38
+ super(FrozenBatchNorm2d, self).__init__()
39
+ self.register_buffer("weight", torch.ones(n))
40
+ self.register_buffer("bias", torch.zeros(n))
41
+ self.register_buffer("running_mean", torch.zeros(n))
42
+ self.register_buffer("running_var", torch.ones(n))
43
+ self.eps = eps
44
+
45
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
46
+ missing_keys, unexpected_keys, error_msgs):
47
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
48
+ if num_batches_tracked_key in state_dict:
49
+ del state_dict[num_batches_tracked_key]
50
+
51
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
52
+ state_dict, prefix, local_metadata, strict,
53
+ missing_keys, unexpected_keys, error_msgs)
54
+
55
+ def forward(self, x):
56
+ # move reshapes to the beginning
57
+ # to make it fuser-friendly
58
+ w = self.weight.reshape(1, -1, 1, 1)
59
+ b = self.bias.reshape(1, -1, 1, 1)
60
+ rv = self.running_var.reshape(1, -1, 1, 1)
61
+ rm = self.running_mean.reshape(1, -1, 1, 1)
62
+ eps = self.eps
63
+ scale = w * (rv + eps).rsqrt()
64
+ bias = b - rm * scale
65
+ return x * scale + bias
66
+
67
+
68
+ class BackboneBase(nn.Module):
69
+
70
+ def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool, n_layers = 4):
71
+ super().__init__()
72
+ for name, parameter in backbone.named_parameters():
73
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
74
+ parameter.requires_grad_(False)
75
+ if return_interm_layers:
76
+ if n_layers == 4:
77
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
78
+ self.strides = [4, 8, 16, 32]
79
+ self.num_channels = [256, 512, 1024, 2048]
80
+ elif n_layers == 3:
81
+ return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
82
+ self.strides = [8, 16, 32]
83
+ self.num_channels = [512, 1024, 2048]
84
+ else:
85
+ raise ValueError("n_layers should be 3 or 4")
86
+
87
+ else:
88
+ return_layers = {'layer4': "0"}
89
+ self.strides = [32]
90
+ self.num_channels = [2048]
91
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
92
+
93
+ def forward(self, tensor_list: NestedTensor):
94
+ xs = self.body(tensor_list.tensors)
95
+ out: Dict[str, NestedTensor] = {}
96
+ for name, x in xs.items():
97
+ m = tensor_list.mask
98
+ assert m is not None
99
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
100
+ out[name] = NestedTensor(x, mask)
101
+ return out
102
+
103
+
104
+ class Backbone(BackboneBase):
105
+ """ResNet backbone with frozen BatchNorm."""
106
+ def __init__(self, name: str,
107
+ train_backbone: bool,
108
+ return_interm_layers: bool,
109
+ dilation: bool,
110
+ n_layers = 4):
111
+ norm_layer = FrozenBatchNorm2d
112
+ backbone = getattr(torchvision.models, name)(
113
+ replace_stride_with_dilation=[False, False, dilation],
114
+ weights='ResNet50_Weights.IMAGENET1K_V1', norm_layer=norm_layer)
115
+ assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded"
116
+ super().__init__(backbone, train_backbone, return_interm_layers, n_layers)
117
+ if dilation:
118
+ self.strides[-1] = self.strides[-1] // 2
119
+
120
+ class Joiner(nn.Sequential):
121
+ def __init__(self, backbone, position_embedding):
122
+ super().__init__(backbone, position_embedding)
123
+ self.strides = backbone.strides
124
+ self.num_channels = backbone.num_channels
125
+
126
+ def forward(self, tensor_list: NestedTensor):
127
+ xs = self[0](tensor_list)
128
+ out: List[NestedTensor] = []
129
+ pos = []
130
+ for name, x in sorted(xs.items()):
131
+ out.append(x)
132
+
133
+ # position encoding
134
+ for x in out:
135
+ pos.append(self[1](x).to(x.tensors.dtype))
136
+
137
+ return out, pos
138
+
139
+
140
+ def build_backbone(config):
141
+ position_embedding = build_position_encoding(config)
142
+ train_backbone = config['lr_backbone'] > 0
143
+ return_interm_layers = True
144
+ n_layers = config['num_feature_levels'] - 1
145
+ backbone = Backbone('resnet50', train_backbone, return_interm_layers, False, n_layers=n_layers)
146
+ model = Joiner(backbone, position_embedding)
147
+ return model
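A hedged sketch of the backbone on its own (normally it is only built through build_descriptor). The import paths assume the rdd repo root is on sys.path; `nested_tensor_from_tensor_list` is the helper from RDD.utils.misc used elsewhere in this commit, and constructing `Backbone` downloads the ImageNet ResNet-50 weights on first use:

    import torch
    from RDD.models.backbone import Backbone
    from RDD.utils.misc import nested_tensor_from_tensor_list

    backbone = Backbone('resnet50', train_backbone=False,
                        return_interm_layers=True, dilation=False, n_layers=3)
    samples = nested_tensor_from_tensor_list([torch.randn(3, 256, 320)])
    feats = backbone(samples)            # dict of NestedTensor, one entry per pyramid level
    for name, f in feats.items():
        print(name, f.tensors.shape)     # strides 8, 16, 32 -> shrinking H, W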
imcui/third_party/rdd/RDD/models/deformable_transformer.py ADDED
@@ -0,0 +1,270 @@
1
+ # Modified from Deformable DETR
2
+ # https://github.com/fundamentalvision/Deformable-DETR
3
+ # ------------------------------------------------------------------------
4
+ # Deformable DETR
5
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
6
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
7
+ # ------------------------------------------------------------------------
8
+ # Modified from DETR (https://github.com/facebookresearch/detr)
9
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
10
+ # ------------------------------------------------------------------------
11
+
12
+ import copy
13
+ from typing import Optional, List
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn, Tensor
19
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
20
+
21
+ from ..utils.misc import inverse_sigmoid
22
+ from .ops.modules import MSDeformAttn
23
+
24
+ class MLP(nn.Module):
25
+ """ Very simple multi-layer perceptron (also called FFN)"""
26
+
27
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
28
+ super().__init__()
29
+ self.num_layers = num_layers
30
+ h = [hidden_dim] * (num_layers - 1)
31
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
32
+
33
+ def forward(self, x):
34
+ for i, layer in enumerate(self.layers):
35
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
36
+ return x
37
+
38
+ def _get_clones(module, N):
39
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
40
+
41
+ class DeformableTransformerEncoderLayer(nn.Module):
42
+ def __init__(self,
43
+ d_model=256, d_ffn=1024,
44
+ dropout=0.1, activation="relu",
45
+ n_levels=4, n_heads=8, n_points=4):
46
+ super().__init__()
47
+
48
+ # self attention
49
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
50
+ self.dropout1 = nn.Dropout(dropout)
51
+ self.norm1 = nn.LayerNorm(d_model)
52
+
53
+ # ffn
54
+ self.linear1 = nn.Linear(d_model, d_ffn)
55
+ self.activation = _get_activation_fn(activation)
56
+ self.dropout2 = nn.Dropout(dropout)
57
+ self.linear2 = nn.Linear(d_ffn, d_model)
58
+ self.dropout3 = nn.Dropout(dropout)
59
+ self.norm2 = nn.LayerNorm(d_model)
60
+
61
+ @staticmethod
62
+ def with_pos_embed(tensor, pos):
63
+ return tensor if pos is None else tensor + pos
64
+
65
+ def forward_ffn(self, src):
66
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
67
+ src = src + self.dropout3(src2)
68
+ src = self.norm2(src)
69
+ return src
70
+
71
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None, debug=False):
72
+ # self attention
73
+ if debug:
74
+ src2, sampled_points = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
75
+ else:
76
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
77
+ src = src + self.dropout1(src2)
78
+ src = self.norm1(src)
79
+
80
+ # ffn
81
+ src = self.forward_ffn(src)
82
+ if debug:
83
+ return src, sampled_points
84
+ return src
85
+
86
+ class DeformableTransformerEncoder(nn.Module):
87
+ def __init__(self, encoder_layer, num_layers):
88
+ super().__init__()
89
+ self.layers = _get_clones(encoder_layer, num_layers)
90
+ self.num_layers = num_layers
91
+
92
+ @staticmethod
93
+ def get_reference_points(spatial_shapes, valid_ratios, device):
94
+ reference_points_list = []
95
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
96
+
97
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
98
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
99
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
100
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
101
+ ref = torch.stack((ref_x, ref_y), -1)
102
+ reference_points_list.append(ref)
103
+ reference_points = torch.cat(reference_points_list, 1)
104
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
105
+ return reference_points
106
+
107
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, debug=False):
108
+ output = src
109
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
110
+ for _, layer in enumerate(self.layers):
111
+ if debug:
112
+ output, sampled_points = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask, debug=debug)
113
+ else:
114
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
115
+ if debug:
116
+ return output, reference_points, sampled_points
117
+ return output
118
+
119
+ class DecoderLayer(nn.Module):
120
+ def __init__(self, d_model=256, n_head=8, dropout=0.1):
121
+ super().__init__()
122
+ self.nhead = n_head
123
+ self.dim = d_model // n_head
124
+ self.attention = LinearAttention()
125
+ self.dropout1 = nn.Dropout(dropout)
126
+ self.norm1 = nn.LayerNorm(d_model)
127
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
128
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
129
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
130
+
131
+ self.mlp = nn.Sequential(
132
+ nn.Linear(d_model*2, d_model*2, bias=False),
133
+ nn.ReLU(True),
134
+ nn.Linear(d_model*2, d_model, bias=False),
135
+ )
136
+
137
+ self.norm2 = nn.LayerNorm(d_model)
138
+
139
+ def forward(self, tgt, src, tgt_mask=None, src_mask=None):
140
+
141
+ bs = tgt.size(0)
142
+ query, key, value = tgt, src, src
143
+
144
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
145
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
146
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
147
+
148
+ tgt2 = self.attention(query, key, value, q_mask=tgt_mask, kv_mask=src_mask)
149
+ tgt2 = tgt2.view(bs, -1, self.nhead*self.dim)
150
+ tgt2 = self.norm1(self.dropout1(tgt2))
151
+ tgt2 = self.mlp(torch.cat([tgt, tgt2], dim=2))
152
+
153
+ tgt2 = self.norm2(tgt2)
154
+
155
+ return tgt + tgt2
156
+
157
+ class MLP(nn.Module):
158
+ """ Very simple multi-layer perceptron (also called FFN)"""
159
+
160
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
161
+ super().__init__()
162
+ self.num_layers = num_layers
163
+ h = [hidden_dim] * (num_layers - 1)
164
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
165
+
166
+ def forward(self, x):
167
+ for i, layer in enumerate(self.layers):
168
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
169
+ return x
170
+
171
+
172
+ class Decoder(nn.Module):
173
+ def __init__(self, decoder_layer, num_layers):
174
+ super().__init__()
175
+ self.layers = _get_clones(decoder_layer, num_layers)
176
+ self.num_layers = num_layers
177
+
178
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
179
+ for layer in self.layers:
180
+ tgt = layer(tgt, memory, tgt_mask=tgt_mask, src_mask=memory_mask)
181
+
182
+ return tgt
183
+
184
+ import math
185
+
186
+ class DeformableTransformer(nn.Module):
187
+ def __init__(self, d_model=256, nhead=8,
188
+ num_encoder_layers=4, dim_feedforward=1024, dropout=0.1,
189
+ activation="relu",
190
+ num_feature_levels=5, enc_n_points=8):
191
+ super().__init__()
192
+
193
+ self.d_model = d_model
194
+ self.nhead = nhead
195
+
196
+ # Encoder and Decoder
197
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
198
+ dropout, activation,
199
+ num_feature_levels, nhead, enc_n_points)
200
+ self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
201
+
202
+ # Embedding for feature levels (multi-scale)
203
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
204
+
205
+ self._reset_parameters()
206
+
207
+ def _reset_parameters(self):
208
+ for p in self.parameters():
209
+ if p.dim() > 1:
210
+ nn.init.xavier_uniform_(p)
211
+ normal_(self.level_embed)
212
+
213
+ def get_valid_ratio(self, mask):
214
+ _, H, W = mask.shape
215
+ valid_H = torch.sum(~mask[:, :, 0], 1)
216
+ valid_W = torch.sum(~mask[:, 0, :], 1)
217
+ valid_ratio_h = valid_H.float() / H
218
+ valid_ratio_w = valid_W.float() / W
219
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
220
+ return valid_ratio
221
+
222
+ def forward(self, srcs, masks, pos_embeds):
223
+
224
+ # Prepare inputs for encoder
225
+ src_flatten = []
226
+ mask_flatten = []
227
+ lvl_pos_embed_flatten = []
228
+ spatial_shapes = []
229
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
230
+ bs, c, h, w = src.shape
231
+ spatial_shape = (h, w)
232
+ spatial_shapes.append(spatial_shape)
233
+ src = src.flatten(2).transpose(1, 2)
234
+ mask = mask.flatten(1)
235
+ pos_embed = pos_embed.flatten(2).transpose(1, 2).to(src.device)
236
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
237
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
238
+ src_flatten.append(src)
239
+ mask_flatten.append(mask)
240
+ src_flatten = torch.cat(src_flatten, 1)
241
+ mask_flatten = torch.cat(mask_flatten, 1)
242
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
243
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
244
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
245
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
246
+
247
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, debug=False)
248
+
249
+ return memory, spatial_shapes, level_start_index
250
+
251
+ def _get_clones(module, N):
252
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
253
+
254
+
255
+ def _get_activation_fn(activation):
256
+ """Return an activation function given a string"""
257
+ if activation == "relu":
258
+ return F.relu
259
+ if activation == "gelu":
260
+ return F.gelu
261
+ if activation == "glu":
262
+ return F.glu
263
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
264
+
265
+
266
+ def build_deforamble_transformer(config):
267
+ return DeformableTransformer(d_model=config['d_model'], nhead=config['nhead'],
268
+ num_encoder_layers=config['num_encoder_layers'], dim_feedforward=config['dim_feedforward'], dropout=config['dropout'],
269
+ activation=config['activation'],
270
+ num_feature_levels=config['num_feature_levels'], enc_n_points=config['enc_n_points'])
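The builder above expects a flat config dict. The following sketch is illustrative only: the key names are taken from build_deforamble_transformer, but the values are hand-written rather than the repo defaults, and on CPU (or without the compiled CUDA op) the attention falls back to the pure-PyTorch path via the bare except in MSDeformAttn:

    import torch
    from RDD.models.deformable_transformer import build_deforamble_transformer

    cfg = dict(d_model=256, nhead=8, num_encoder_layers=4, dim_feedforward=1024,
               dropout=0.1, activation="relu", num_feature_levels=5, enc_n_points=8)
    transformer = build_deforamble_transformer(cfg)

    # five dummy feature levels with d_model channels; masks are all-False (no padding)
    srcs = [torch.randn(1, 256, s, s) for s in (32, 16, 8, 4, 2)]
    masks = [torch.zeros(1, s.shape[-2], s.shape[-1], dtype=torch.bool) for s in srcs]
    pos = [torch.randn_like(s) for s in srcs]
    memory, spatial_shapes, level_start_index = transformer(srcs, masks, pos)
    # memory: (1, sum(H_l * W_l), 256) encoded tokens across all levels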
imcui/third_party/rdd/RDD/models/descriptor.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ..utils.misc import NestedTensor, nested_tensor_from_tensor_list
5
+ import torchvision.transforms as transforms
6
+ from .backbone import build_backbone
7
+ from .deformable_transformer import build_deforamble_transformer
8
+
9
+ class BasicLayer(nn.Module):
10
+ """
11
+ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
12
+ """
13
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
14
+ super().__init__()
15
+ self.layer = nn.Sequential(
16
+ nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
17
+ nn.BatchNorm2d(out_channels, affine=False),
18
+ nn.ReLU(inplace = False),
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.layer(x)
23
+
24
+ class RDD_Descriptor(nn.Module):
25
+ def __init__(self, backbone, transformer, num_feature_levels):
26
+ super().__init__()
27
+ self.transformer = transformer
28
+ self.hidden_dim = transformer.d_model
29
+ self.num_feature_levels = num_feature_levels
30
+
31
+ self.matchibility_head = nn.Sequential(
32
+ BasicLayer(256, 128, 1, padding=0),
33
+ BasicLayer(128, 64, 1, padding=0),
34
+ nn.Conv2d (64, 1, 1),
35
+ nn.Sigmoid()
36
+ )
37
+
38
+ if num_feature_levels > 1:
39
+ num_backbone_outs = len(backbone.strides)
40
+ input_proj_list = []
41
+ for _ in range(num_backbone_outs):
42
+ in_channels = backbone.num_channels[_]
43
+ input_proj_list.append(nn.Sequential(
44
+ nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1),
45
+ nn.GroupNorm(32, self.hidden_dim),
46
+ ))
47
+ for _ in range(num_feature_levels - num_backbone_outs):
48
+ input_proj_list.append(nn.Sequential(
49
+ nn.Conv2d(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1),
50
+ nn.GroupNorm(32, self.hidden_dim),
51
+ ))
52
+ in_channels = self.hidden_dim
53
+ self.input_proj = nn.ModuleList(input_proj_list)
54
+ else:
55
+ self.input_proj = nn.ModuleList([
56
+ nn.Sequential(
57
+ nn.Conv2d(backbone.num_channels[0], self.hidden_dim, kernel_size=1),
58
+ nn.GroupNorm(32, self.hidden_dim),
59
+ )])
60
+ self.backbone = backbone
61
+ self.stride = backbone.strides[0]
62
+ for proj in self.input_proj:
63
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
64
+ nn.init.constant_(proj[0].bias, 0)
65
+
66
+ def forward(self, samples: NestedTensor):
67
+
68
+ if not isinstance(samples, NestedTensor):
69
+ samples = nested_tensor_from_tensor_list(samples)
70
+
71
+ features, pos = self.backbone(samples)
72
+
73
+ srcs = []
74
+ masks = []
75
+ for l, feat in enumerate(features):
76
+ src, mask = feat.decompose()
77
+ srcs.append(self.input_proj[l](src))
78
+ masks.append(mask)
79
+ assert mask is not None
80
+ if self.num_feature_levels > len(srcs):
81
+ _len_srcs = len(srcs)
82
+ for l in range(_len_srcs, self.num_feature_levels):
83
+ if l == _len_srcs:
84
+ src = self.input_proj[l](features[-1].tensors)
85
+ else:
86
+ src = self.input_proj[l](srcs[-1])
87
+ m = samples.mask
88
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
89
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
90
+ srcs.append(src)
91
+ masks.append(mask)
92
+ pos.append(pos_l)
93
+
94
+ flatten_feats, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos)
95
+ # Reshape the flattened features back to the original spatial shapes
96
+ feats = []
97
+ level_start_index = torch.cat((level_start_index, torch.tensor([flatten_feats.shape[1]+1]).to(level_start_index.device)))
98
+ for i, shape in enumerate(spatial_shapes):
99
+ assert len(shape) == 2
100
+ temp = flatten_feats[:, level_start_index[i] : level_start_index[i+1], :]
101
+ feats.append(temp.transpose(1, 2).view(-1, self.hidden_dim, *shape))
102
+
103
+ # Sum up the features from different levels
104
+ final_feature = feats[0]
105
+ for feat in feats[1:]:
106
+ final_feature = final_feature + F.interpolate(feat, size=final_feature.shape[-2:], mode='bilinear', align_corners=True)
107
+
108
+ matchibility = self.matchibility_head(final_feature)
109
+
110
+ return final_feature, matchibility
111
+
112
+
113
+ def build_descriptor(config):
114
+ backbone = build_backbone(config)
115
+ transformer = build_deforamble_transformer(config)
116
+ return RDD_Descriptor(backbone, transformer, config['num_feature_levels'])
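A hedged end-to-end sketch for the descriptor branch. It assumes the config dict loaded from this commit's configs/default.yaml supplies every key the builders above read (the file is not reproduced here), and that the rdd repo root is the working directory:

    import torch, yaml
    from RDD.models.descriptor import build_descriptor

    with open("configs/default.yaml") as f:
        config = yaml.safe_load(f)
    descriptor = build_descriptor(config)
    feats, matchability = descriptor([torch.randn(3, 256, 256)])
    # feats: dense descriptor map at the finest backbone stride
    # matchability: (1, 1, h, w) sigmoid scores in [0, 1]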
imcui/third_party/rdd/RDD/models/detector.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision.models import resnet
5
+ from typing import Optional, Callable
6
+ from ..utils.misc import NestedTensor
7
+
8
+ class ConvBlock(nn.Module):
9
+ def __init__(self, in_channels, out_channels,
10
+ gate: Optional[Callable[..., nn.Module]] = None,
11
+ norm_layer: Optional[Callable[..., nn.Module]] = None):
12
+ super().__init__()
13
+ if gate is None:
14
+ self.gate = nn.ReLU(inplace=False)
15
+ else:
16
+ self.gate = gate
17
+ if norm_layer is None:
18
+ norm_layer = nn.BatchNorm2d
19
+ self.conv1 = resnet.conv3x3(in_channels, out_channels)
20
+ self.bn1 = norm_layer(out_channels)
21
+ self.conv2 = resnet.conv3x3(out_channels, out_channels)
22
+ self.bn2 = norm_layer(out_channels)
23
+
24
+ def forward(self, x):
25
+ x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
26
+ x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
27
+ return x
28
+
29
+ class ResBlock(nn.Module):
30
+ expansion: int = 1
31
+
32
+ def __init__(
33
+ self,
34
+ inplanes: int,
35
+ planes: int,
36
+ stride: int = 1,
37
+ downsample: Optional[nn.Module] = None,
38
+ groups: int = 1,
39
+ base_width: int = 64,
40
+ dilation: int = 1,
41
+ gate: Optional[Callable[..., nn.Module]] = None,
42
+ norm_layer: Optional[Callable[..., nn.Module]] = None
43
+ ) -> None:
44
+ super(ResBlock, self).__init__()
45
+ if gate is None:
46
+ self.gate = nn.ReLU(inplace=False)
47
+ else:
48
+ self.gate = gate
49
+ if norm_layer is None:
50
+ norm_layer = nn.BatchNorm2d
51
+ if groups != 1 or base_width != 64:
52
+ raise ValueError('ResBlock only supports groups=1 and base_width=64')
53
+ if dilation > 1:
54
+ raise NotImplementedError("Dilation > 1 not supported in ResBlock")
55
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
56
+ self.conv1 = resnet.conv3x3(inplanes, planes, stride)
57
+ self.bn1 = norm_layer(planes)
58
+ self.conv2 = resnet.conv3x3(planes, planes)
59
+ self.bn2 = norm_layer(planes)
60
+ self.downsample = downsample
61
+ self.stride = stride
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ identity = x
65
+
66
+ out = self.conv1(x)
67
+ out = self.bn1(out)
68
+ out = self.gate(out)
69
+
70
+ out = self.conv2(out)
71
+ out = self.bn2(out)
72
+
73
+ if self.downsample is not None:
74
+ identity = self.downsample(x)
75
+
76
+ out = out + identity
77
+ out = self.gate(out)
78
+
79
+ return out
80
+
81
+ class RDD_detector(nn.Module):
82
+ def __init__(self, block_dims, hidden_dim=128):
83
+ super().__init__()
84
+ self.gate = nn.ReLU(inplace=False)
85
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
86
+ self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4)
87
+ self.block1 = ConvBlock(3, block_dims[0], self.gate, nn.BatchNorm2d)
88
+ self.block2 = ResBlock(inplanes=block_dims[0], planes=block_dims[1], stride=1,
89
+ downsample=nn.Conv2d(block_dims[0], block_dims[1], 1),
90
+ gate=self.gate,
91
+ norm_layer=nn.BatchNorm2d)
92
+ self.block3 = ResBlock(inplanes=block_dims[1], planes=block_dims[2], stride=1,
93
+ downsample=nn.Conv2d(block_dims[1], block_dims[2], 1),
94
+ gate=self.gate,
95
+ norm_layer=nn.BatchNorm2d)
96
+ self.block4 = ResBlock(inplanes=block_dims[2], planes=block_dims[3], stride=1,
97
+ downsample=nn.Conv2d(block_dims[2], block_dims[3], 1),
98
+ gate=self.gate,
99
+ norm_layer=nn.BatchNorm2d)
100
+
101
+ self.conv1 = resnet.conv1x1(block_dims[0], hidden_dim // 4)
102
+ self.conv2 = resnet.conv1x1(block_dims[1], hidden_dim // 4)
103
+ self.conv3 = resnet.conv1x1(block_dims[2], hidden_dim // 4)
104
+ self.conv4 = resnet.conv1x1(block_dims[3], hidden_dim // 4)
105
+
106
+ self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
107
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
108
+ self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
109
+ self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)
110
+
111
+ self.convhead2 = nn.Sequential(
112
+ resnet.conv1x1(hidden_dim, 1),
113
+ nn.Sigmoid()
114
+ )
115
+
116
+ def forward(self, samples: NestedTensor):
117
+ x1 = self.block1(samples.tensors)
118
+ x2 = self.pool2(x1)
119
+ x2 = self.block2(x2) # B x c2 x H/2 x W/2
120
+ x3 = self.pool4(x2)
121
+ x3 = self.block3(x3) # B x c3 x H/8 x W/8
122
+ x4 = self.pool4(x3)
123
+ x4 = self.block4(x4)
124
+
125
+ x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
126
+ x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
127
+ x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
128
+ x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
129
+
130
+ x2_up = self.upsample2(x2)
131
+ x3_up = self.upsample8(x3)
132
+ x4_up = self.upsample32(x4)
133
+
134
+ x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
135
+ scoremap = self.convhead2(x1234)
136
+
137
+ return scoremap
138
+
139
+ def build_detector(config):
140
+ block_dims = config['block_dims']
141
+ return RDD_detector(block_dims, block_dims[-1])
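A quick shape check for the score head above; the block_dims values are made up for the sketch (the real ones come from the config), and the input side should be divisible by 32 so the upsampled branches line up:

    import torch
    from RDD.models.detector import build_detector
    from RDD.utils.misc import nested_tensor_from_tensor_list

    det = build_detector({'block_dims': [8, 16, 32, 64]})
    samples = nested_tensor_from_tensor_list([torch.randn(3, 128, 128)])
    scoremap = det(samples)   # (1, 1, 128, 128), per-pixel keypoint scores in [0, 1]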
imcui/third_party/rdd/RDD/models/interpolator.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
3
+ https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ class InterpolateSparse2d(nn.Module):
11
+ """ Efficiently interpolate tensor at given sparse 2D positions. """
12
+ def __init__(self, mode = 'bilinear', align_corners = True):
13
+ super().__init__()
14
+ self.mode = mode
15
+ self.align_corners = align_corners
16
+
17
+ def normgrid(self, x, H, W):
18
+ """ Normalize coords to [-1,1]. """
19
+ return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
20
+
21
+ def forward(self, x, pos, H, W):
22
+ """
23
+ Input
24
+ x: [B, C, H, W] feature tensor
25
+ pos: [B, N, 2] tensor of positions
26
+ H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
27
+
28
+ Returns
29
+ [B, N, C] sampled channels at 2d positions
30
+ """
31
+ grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
32
+ x = F.grid_sample(x, grid, mode = self.mode , align_corners = self.align_corners)
33
+ return x.permute(0,2,3,1).squeeze(-2)
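A small usage sketch for the sampler above, with keypoints given in original-image pixel coordinates as the docstring states (tensor sizes and import path are illustrative):

    import torch
    from RDD.models.interpolator import InterpolateSparse2d

    interp = InterpolateSparse2d(mode='bilinear')
    featmap = torch.randn(1, 256, 60, 80)                  # B x C feature map at 1/8 resolution
    kpts = torch.tensor([[[12.0, 34.0], [400.0, 300.0]]])  # B x N x 2, (x, y) in pixels
    desc = interp(featmap, kpts, H=480, W=640)             # B x N x 256 sampled descriptors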
imcui/third_party/rdd/RDD/models/ops/functions/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn_func import MSDeformAttnFunction
13
+
imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,72 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.autograd import Function
19
+ from torch.autograd.function import once_differentiable
20
+
21
+ try:
22
+ import MultiScaleDeformableAttention as MSDA
23
+ except ModuleNotFoundError as e:
24
+ info_string = (
25
+ "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
26
+ "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
27
+ "\t`sh make.sh`\n"
28
+ )
29
+ print(info_string)
30
+
31
+
32
+ class MSDeformAttnFunction(Function):
33
+ @staticmethod
34
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
35
+ ctx.im2col_step = im2col_step
36
+ output = MSDA.ms_deform_attn_forward(
37
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
38
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
39
+ return output
40
+
41
+ @staticmethod
42
+ @once_differentiable
43
+ def backward(ctx, grad_output):
44
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
45
+ grad_value, grad_sampling_loc, grad_attn_weight = \
46
+ MSDA.ms_deform_attn_backward(
47
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
48
+
49
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
50
+
51
+
52
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
53
+ # for debug and test only,
54
+ # need to use cuda version instead
55
+ N_, S_, M_, D_ = value.shape
56
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
57
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
58
+ sampling_grids = 2 * sampling_locations - 1
59
+ sampling_value_list = []
60
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
61
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
62
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
63
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
64
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
65
+ # N_*M_, D_, Lq_, P_
66
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
67
+ mode='bilinear', padding_mode='zeros', align_corners=False)
68
+ sampling_value_list.append(sampling_value_l_)
69
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
70
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
71
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
72
+ return output.transpose(1, 2).contiguous()
imcui/third_party/rdd/RDD/models/ops/make.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ # ------------------------------------------------------------------------------------------------
9
+
10
+ # Copyright (c) Facebook, Inc. and its affiliates.
11
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
12
+
13
+ python setup.py build install
imcui/third_party/rdd/RDD/models/ops/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from .ms_deform_attn import MSDeformAttn
imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,125 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import warnings
17
+ import math
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ from torch.nn.init import xavier_uniform_, constant_
23
+
24
+ from ..functions import MSDeformAttnFunction
25
+ from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
26
+
27
+
28
+ def _is_power_of_2(n):
29
+ if (not isinstance(n, int)) or (n < 0):
30
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
31
+ return (n & (n-1) == 0) and n != 0
32
+
33
+
34
+ class MSDeformAttn(nn.Module):
35
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
36
+ """
37
+ Multi-Scale Deformable Attention Module
38
+ :param d_model hidden dimension
39
+ :param n_levels number of feature levels
40
+ :param n_heads number of attention heads
41
+ :param n_points number of sampling points per attention head per feature level
42
+ """
43
+ super().__init__()
44
+ if d_model % n_heads != 0:
45
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
46
+ _d_per_head = d_model // n_heads
47
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
48
+ if not _is_power_of_2(_d_per_head):
49
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
50
+ "which is more efficient in our CUDA implementation.")
51
+
52
+ self.im2col_step = 128
53
+
54
+ self.d_model = d_model
55
+ self.n_levels = n_levels
56
+ self.n_heads = n_heads
57
+ self.n_points = n_points
58
+
59
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
60
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
61
+ self.value_proj = nn.Linear(d_model, d_model)
62
+ self.output_proj = nn.Linear(d_model, d_model)
63
+
64
+ self._reset_parameters()
65
+
66
+ def _reset_parameters(self):
67
+ constant_(self.sampling_offsets.weight.data, 0.)
68
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
69
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
70
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
71
+ for i in range(self.n_points):
72
+ grid_init[:, :, i, :] *= i + 1
73
+ with torch.no_grad():
74
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
75
+ constant_(self.attention_weights.weight.data, 0.)
76
+ constant_(self.attention_weights.bias.data, 0.)
77
+ xavier_uniform_(self.value_proj.weight.data)
78
+ constant_(self.value_proj.bias.data, 0.)
79
+ xavier_uniform_(self.output_proj.weight.data)
80
+ constant_(self.output_proj.bias.data, 0.)
81
+
82
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
83
+ """
84
+ :param query (N, Length_{query}, C)
85
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
86
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
87
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
88
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
89
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
90
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
91
+
92
+ :return output (N, Length_{query}, C)
93
+ """
94
+ N, Len_q, _ = query.shape
95
+ N, Len_in, _ = input_flatten.shape
96
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
97
+
98
+ value = self.value_proj(input_flatten)
99
+ if input_padding_mask is not None:
100
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
101
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
102
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
103
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
104
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
105
+ # N, Len_q, n_heads, n_levels, n_points, 2
106
+ if reference_points.shape[-1] == 2:
107
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
108
+ sampling_locations = reference_points[:, :, None, :, None, :] \
109
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
110
+ elif reference_points.shape[-1] == 4:
111
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
112
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
113
+ else:
114
+ raise ValueError(
115
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
116
+ try:
117
+ output = MSDeformAttnFunction.apply(
118
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
119
+ except:
120
+ # CPU
121
+ output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
122
+ # # For FLOPs calculation only
123
+ # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
124
+ output = self.output_proj(output)
125
+ return output
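A hedged shape sketch for the attention module above, using the flattened multi-level layout its docstring describes. Without the compiled MultiScaleDeformableAttention extension (or on CPU) the bare except routes the call through ms_deform_attn_core_pytorch:

    import torch
    from RDD.models.ops.modules import MSDeformAttn

    attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
    shapes = torch.tensor([[32, 32], [16, 16]], dtype=torch.long)       # (n_levels, 2)
    starts = torch.cat((shapes.new_zeros(1), shapes.prod(1).cumsum(0)[:-1]))
    tokens = torch.randn(1, int(shapes.prod(1).sum()), 256)             # (N, sum H_l*W_l, C)
    ref = torch.rand(1, tokens.shape[1], 2, 2)                          # (N, Len_q, n_levels, 2) in [0, 1]
    out = attn(tokens, ref, tokens, shapes, starts)                     # (N, Len_q, 256)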
imcui/third_party/rdd/RDD/models/ops/setup.py ADDED
@@ -0,0 +1,78 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ import os
13
+ import glob
14
+
15
+ import torch
16
+
17
+ from torch.utils.cpp_extension import CUDA_HOME
18
+ from torch.utils.cpp_extension import CppExtension
19
+ from torch.utils.cpp_extension import CUDAExtension
20
+
21
+ from setuptools import find_packages
22
+ from setuptools import setup
23
+
24
+ requirements = ["torch", "torchvision"]
25
+
26
+ def get_extensions():
27
+ this_dir = os.path.dirname(os.path.abspath(__file__))
28
+ extensions_dir = os.path.join(this_dir, "src")
29
+
30
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
31
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
32
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
33
+
34
+ sources = main_file + source_cpu
35
+ extension = CppExtension
36
+ extra_compile_args = {"cxx": []}
37
+ define_macros = []
38
+
39
+ # Force cuda since torch ask for a device, not if cuda is in fact available.
40
+ if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
41
+ extension = CUDAExtension
42
+ sources += source_cuda
43
+ define_macros += [("WITH_CUDA", None)]
44
+ extra_compile_args["nvcc"] = [
45
+ "-DCUDA_HAS_FP16=1",
46
+ "-D__CUDA_NO_HALF_OPERATORS__",
47
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
48
+ "-D__CUDA_NO_HALF2_OPERATORS__",
49
+ ]
50
+ else:
51
+ if CUDA_HOME is None:
52
+ raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
53
+ else:
54
+ raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
55
+
56
+ sources = [os.path.join(extensions_dir, s) for s in sources]
57
+ include_dirs = [extensions_dir]
58
+ ext_modules = [
59
+ extension(
60
+ "MultiScaleDeformableAttention",
61
+ sources,
62
+ include_dirs=include_dirs,
63
+ define_macros=define_macros,
64
+ extra_compile_args=extra_compile_args,
65
+ )
66
+ ]
67
+ return ext_modules
68
+
69
+ setup(
70
+ name="MultiScaleDeformableAttention",
71
+ version="1.0",
72
+ author="Weijie Su",
73
+ url="https://github.com/fundamentalvision/Deformable-DETR",
74
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
75
+ packages=find_packages(exclude=("configs", "tests",)),
76
+ ext_modules=get_extensions(),
77
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
78
+ )
imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,46 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+
18
+ #include <ATen/ATen.h>
19
+ #include <ATen/cuda/CUDAContext.h>
20
+
21
+
22
+ at::Tensor
23
+ ms_deform_attn_cpu_forward(
24
+ const at::Tensor &value,
25
+ const at::Tensor &spatial_shapes,
26
+ const at::Tensor &level_start_index,
27
+ const at::Tensor &sampling_loc,
28
+ const at::Tensor &attn_weight,
29
+ const int im2col_step)
30
+ {
31
+ AT_ERROR("Not implement on cpu");
32
+ }
33
+
34
+ std::vector<at::Tensor>
35
+ ms_deform_attn_cpu_backward(
36
+ const at::Tensor &value,
37
+ const at::Tensor &spatial_shapes,
38
+ const at::Tensor &level_start_index,
39
+ const at::Tensor &sampling_loc,
40
+ const at::Tensor &attn_weight,
41
+ const at::Tensor &grad_output,
42
+ const int im2col_step)
43
+ {
44
+ AT_ERROR("Not implement on cpu");
45
+ }
46
+
imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,38 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor
20
+ ms_deform_attn_cpu_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step);
27
+
28
+ std::vector<at::Tensor>
29
+ ms_deform_attn_cpu_backward(
30
+ const at::Tensor &value,
31
+ const at::Tensor &spatial_shapes,
32
+ const at::Tensor &level_start_index,
33
+ const at::Tensor &sampling_loc,
34
+ const at::Tensor &attn_weight,
35
+ const at::Tensor &grad_output,
36
+ const int im2col_step);
37
+
38
+
imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,158 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include <vector>
17
+ #include "cuda/ms_deform_im2col_cuda.cuh"
18
+
19
+ #include <ATen/ATen.h>
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <cuda.h>
22
+ #include <cuda_runtime.h>
23
+
24
+
25
+ at::Tensor ms_deform_attn_cuda_forward(
26
+ const at::Tensor &value,
27
+ const at::Tensor &spatial_shapes,
28
+ const at::Tensor &level_start_index,
29
+ const at::Tensor &sampling_loc,
30
+ const at::Tensor &attn_weight,
31
+ const int im2col_step)
32
+ {
33
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
34
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
35
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
36
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
37
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
38
+
39
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
40
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
41
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
42
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
43
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
44
+
45
+ const int batch = value.size(0);
46
+ const int spatial_size = value.size(1);
47
+ const int num_heads = value.size(2);
48
+ const int channels = value.size(3);
49
+
50
+ const int num_levels = spatial_shapes.size(0);
51
+
52
+ const int num_query = sampling_loc.size(1);
53
+ const int num_point = sampling_loc.size(4);
54
+
55
+ const int im2col_step_ = std::min(batch, im2col_step);
56
+
57
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
58
+
59
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
60
+
61
+ const int batch_n = im2col_step_;
62
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
63
+ auto per_value_size = spatial_size * num_heads * channels;
64
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
65
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
66
+ for (int n = 0; n < batch/im2col_step_; ++n)
67
+ {
68
+ auto columns = output_n.select(0, n);
69
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
70
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
71
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
72
+ spatial_shapes.data<int64_t>(),
73
+ level_start_index.data<int64_t>(),
74
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
75
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
76
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
77
+ columns.data<scalar_t>());
78
+
79
+ }));
80
+ }
81
+
82
+ output = output.view({batch, num_query, num_heads*channels});
83
+
84
+ return output;
85
+ }
86
+
87
+
88
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
89
+ const at::Tensor &value,
90
+ const at::Tensor &spatial_shapes,
91
+ const at::Tensor &level_start_index,
92
+ const at::Tensor &sampling_loc,
93
+ const at::Tensor &attn_weight,
94
+ const at::Tensor &grad_output,
95
+ const int im2col_step)
96
+ {
97
+
98
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
99
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
100
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
101
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
102
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
103
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
104
+
105
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
106
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
107
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
108
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
109
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
110
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
111
+
112
+ const int batch = value.size(0);
113
+ const int spatial_size = value.size(1);
114
+ const int num_heads = value.size(2);
115
+ const int channels = value.size(3);
116
+
117
+ const int num_levels = spatial_shapes.size(0);
118
+
119
+ const int num_query = sampling_loc.size(1);
120
+ const int num_point = sampling_loc.size(4);
121
+
122
+ const int im2col_step_ = std::min(batch, im2col_step);
123
+
124
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
125
+
126
+ auto grad_value = at::zeros_like(value);
127
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
128
+ auto grad_attn_weight = at::zeros_like(attn_weight);
129
+
130
+ const int batch_n = im2col_step_;
131
+ auto per_value_size = spatial_size * num_heads * channels;
132
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
133
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
134
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
135
+
136
+ for (int n = 0; n < batch/im2col_step_; ++n)
137
+ {
138
+ auto grad_output_g = grad_output_n.select(0, n);
139
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
140
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
141
+ grad_output_g.data<scalar_t>(),
142
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
143
+ spatial_shapes.data<int64_t>(),
144
+ level_start_index.data<int64_t>(),
145
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
147
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
148
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
149
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
150
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
151
+
152
+ }));
153
+ }
154
+
155
+ return {
156
+ grad_value, grad_sampling_loc, grad_attn_weight
157
+ };
158
+ }
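The forward entry point above expects value flattened over all pyramid levels, sampling locations normalized to [0, 1], attention weights that sum to 1 over the levels-by-points dimensions, and a batch size divisible by im2col_step. A minimal Python sketch of shaping such inputs (all sizes and names here are illustrative, not taken from the repo; test.py later in this diff exercises the same conventions):

import torch

N, M, D = 2, 8, 32                                   # batch, heads, channels per head
Lq, L, P = 100, 2, 4                                 # queries, feature levels, points per level
spatial_shapes = torch.as_tensor([[64, 48], [32, 24]], dtype=torch.long, device='cuda')
level_start_index = torch.cat((spatial_shapes.new_zeros(1),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
S = int(spatial_shapes.prod(1).sum())                # flattened spatial size over all levels
value = torch.rand(N, S, M, D, device='cuda')
sampling_loc = torch.rand(N, Lq, M, L, P, 2, device='cuda')          # normalized to [0, 1]
attn_weight = torch.rand(N, Lq, M, L, P, device='cuda') + 1e-5
attn_weight = attn_weight / attn_weight.sum(-1, keepdim=True).sum(-2, keepdim=True)
# batch must be divisible by im2col_step (see the assert above); im2col_step = 2 works here.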
imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,35 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+ #include <torch/extension.h>
18
+
19
+ at::Tensor ms_deform_attn_cuda_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step);
26
+
27
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
28
+ const at::Tensor &value,
29
+ const at::Tensor &spatial_shapes,
30
+ const at::Tensor &level_start_index,
31
+ const at::Tensor &sampling_loc,
32
+ const at::Tensor &attn_weight,
33
+ const at::Tensor &grad_output,
34
+ const int im2col_step);
35
+
imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1332 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ /*!
13
+ * Copyright (c) Facebook, Inc. and its affiliates.
14
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
15
+ */
16
+
17
+ #include <cstdio>
18
+ #include <algorithm>
19
+ #include <cstring>
20
+
21
+ #include <ATen/ATen.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+
24
+ #include <THC/THCAtomics.cuh>
25
+
26
+ #define CUDA_KERNEL_LOOP(i, n) \
27
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
28
+ i < (n); \
29
+ i += blockDim.x * gridDim.x)
30
+
31
+ const int CUDA_NUM_THREADS = 1024;
32
+ inline int GET_BLOCKS(const int N, const int num_threads)
33
+ {
34
+ return (N + num_threads - 1) / num_threads;
35
+ }
36
+
37
+
38
+ template <typename scalar_t>
39
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
40
+ const int &height, const int &width, const int &nheads, const int &channels,
41
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
42
+ {
43
+ const int h_low = floor(h);
44
+ const int w_low = floor(w);
45
+ const int h_high = h_low + 1;
46
+ const int w_high = w_low + 1;
47
+
48
+ const scalar_t lh = h - h_low;
49
+ const scalar_t lw = w - w_low;
50
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
51
+
52
+ const int w_stride = nheads * channels;
53
+ const int h_stride = width * w_stride;
54
+ const int h_low_ptr_offset = h_low * h_stride;
55
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
56
+ const int w_low_ptr_offset = w_low * w_stride;
57
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
58
+ const int base_ptr = m * channels + c;
59
+
60
+ scalar_t v1 = 0;
61
+ if (h_low >= 0 && w_low >= 0)
62
+ {
63
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
64
+ v1 = bottom_data[ptr1];
65
+ }
66
+ scalar_t v2 = 0;
67
+ if (h_low >= 0 && w_high <= width - 1)
68
+ {
69
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
70
+ v2 = bottom_data[ptr2];
71
+ }
72
+ scalar_t v3 = 0;
73
+ if (h_high <= height - 1 && w_low >= 0)
74
+ {
75
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
76
+ v3 = bottom_data[ptr3];
77
+ }
78
+ scalar_t v4 = 0;
79
+ if (h_high <= height - 1 && w_high <= width - 1)
80
+ {
81
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
82
+ v4 = bottom_data[ptr4];
83
+ }
84
+
85
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
86
+
87
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
88
+ return val;
89
+ }
90
+
91
+
92
+ template <typename scalar_t>
93
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
94
+ const int &height, const int &width, const int &nheads, const int &channels,
95
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
96
+ const scalar_t &top_grad,
97
+ const scalar_t &attn_weight,
98
+ scalar_t* &grad_value,
99
+ scalar_t* grad_sampling_loc,
100
+ scalar_t* grad_attn_weight)
101
+ {
102
+ const int h_low = floor(h);
103
+ const int w_low = floor(w);
104
+ const int h_high = h_low + 1;
105
+ const int w_high = w_low + 1;
106
+
107
+ const scalar_t lh = h - h_low;
108
+ const scalar_t lw = w - w_low;
109
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
110
+
111
+ const int w_stride = nheads * channels;
112
+ const int h_stride = width * w_stride;
113
+ const int h_low_ptr_offset = h_low * h_stride;
114
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
115
+ const int w_low_ptr_offset = w_low * w_stride;
116
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
117
+ const int base_ptr = m * channels + c;
118
+
119
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
120
+ const scalar_t top_grad_value = top_grad * attn_weight;
121
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
122
+
123
+ scalar_t v1 = 0;
124
+ if (h_low >= 0 && w_low >= 0)
125
+ {
126
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
127
+ v1 = bottom_data[ptr1];
128
+ grad_h_weight -= hw * v1;
129
+ grad_w_weight -= hh * v1;
130
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
131
+ }
132
+ scalar_t v2 = 0;
133
+ if (h_low >= 0 && w_high <= width - 1)
134
+ {
135
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
136
+ v2 = bottom_data[ptr2];
137
+ grad_h_weight -= lw * v2;
138
+ grad_w_weight += hh * v2;
139
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
140
+ }
141
+ scalar_t v3 = 0;
142
+ if (h_high <= height - 1 && w_low >= 0)
143
+ {
144
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
145
+ v3 = bottom_data[ptr3];
146
+ grad_h_weight += hw * v3;
147
+ grad_w_weight -= lh * v3;
148
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
149
+ }
150
+ scalar_t v4 = 0;
151
+ if (h_high <= height - 1 && w_high <= width - 1)
152
+ {
153
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
154
+ v4 = bottom_data[ptr4];
155
+ grad_h_weight += lw * v4;
156
+ grad_w_weight += lh * v4;
157
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
158
+ }
159
+
160
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
161
+ *grad_attn_weight = top_grad * val;
162
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
163
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
164
+ }
165
+
166
+
167
+ template <typename scalar_t>
168
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
169
+ const int &height, const int &width, const int &nheads, const int &channels,
170
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
171
+ const scalar_t &top_grad,
172
+ const scalar_t &attn_weight,
173
+ scalar_t* &grad_value,
174
+ scalar_t* grad_sampling_loc,
175
+ scalar_t* grad_attn_weight)
176
+ {
177
+ const int h_low = floor(h);
178
+ const int w_low = floor(w);
179
+ const int h_high = h_low + 1;
180
+ const int w_high = w_low + 1;
181
+
182
+ const scalar_t lh = h - h_low;
183
+ const scalar_t lw = w - w_low;
184
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
185
+
186
+ const int w_stride = nheads * channels;
187
+ const int h_stride = width * w_stride;
188
+ const int h_low_ptr_offset = h_low * h_stride;
189
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
190
+ const int w_low_ptr_offset = w_low * w_stride;
191
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
192
+ const int base_ptr = m * channels + c;
193
+
194
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
195
+ const scalar_t top_grad_value = top_grad * attn_weight;
196
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
197
+
198
+ scalar_t v1 = 0;
199
+ if (h_low >= 0 && w_low >= 0)
200
+ {
201
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
202
+ v1 = bottom_data[ptr1];
203
+ grad_h_weight -= hw * v1;
204
+ grad_w_weight -= hh * v1;
205
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
206
+ }
207
+ scalar_t v2 = 0;
208
+ if (h_low >= 0 && w_high <= width - 1)
209
+ {
210
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
211
+ v2 = bottom_data[ptr2];
212
+ grad_h_weight -= lw * v2;
213
+ grad_w_weight += hh * v2;
214
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
215
+ }
216
+ scalar_t v3 = 0;
217
+ if (h_high <= height - 1 && w_low >= 0)
218
+ {
219
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
220
+ v3 = bottom_data[ptr3];
221
+ grad_h_weight += hw * v3;
222
+ grad_w_weight -= lh * v3;
223
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
224
+ }
225
+ scalar_t v4 = 0;
226
+ if (h_high <= height - 1 && w_high <= width - 1)
227
+ {
228
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
229
+ v4 = bottom_data[ptr4];
230
+ grad_h_weight += lw * v4;
231
+ grad_w_weight += lh * v4;
232
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
233
+ }
234
+
235
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
236
+ atomicAdd(grad_attn_weight, top_grad * val);
237
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
238
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
239
+ }
240
+
241
+
242
+ template <typename scalar_t>
243
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
244
+ const scalar_t *data_value,
245
+ const int64_t *data_spatial_shapes,
246
+ const int64_t *data_level_start_index,
247
+ const scalar_t *data_sampling_loc,
248
+ const scalar_t *data_attn_weight,
249
+ const int batch_size,
250
+ const int spatial_size,
251
+ const int num_heads,
252
+ const int channels,
253
+ const int num_levels,
254
+ const int num_query,
255
+ const int num_point,
256
+ scalar_t *data_col)
257
+ {
258
+ CUDA_KERNEL_LOOP(index, n)
259
+ {
260
+ int _temp = index;
261
+ const int c_col = _temp % channels;
262
+ _temp /= channels;
263
+ const int sampling_index = _temp;
264
+ const int m_col = _temp % num_heads;
265
+ _temp /= num_heads;
266
+ const int q_col = _temp % num_query;
267
+ _temp /= num_query;
268
+ const int b_col = _temp;
269
+
270
+ scalar_t *data_col_ptr = data_col + index;
271
+ int data_weight_ptr = sampling_index * num_levels * num_point;
272
+ int data_loc_w_ptr = data_weight_ptr << 1;
273
+ const int qid_stride = num_heads * channels;
274
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
275
+ scalar_t col = 0;
276
+
277
+ for (int l_col=0; l_col < num_levels; ++l_col)
278
+ {
279
+ const int level_start_id = data_level_start_index[l_col];
280
+ const int spatial_h_ptr = l_col << 1;
281
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
282
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
283
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
284
+ for (int p_col=0; p_col < num_point; ++p_col)
285
+ {
286
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
287
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
288
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
289
+
290
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
291
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
292
+
293
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
294
+ {
295
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
296
+ }
297
+
298
+ data_weight_ptr += 1;
299
+ data_loc_w_ptr += 2;
300
+ }
301
+ }
302
+ *data_col_ptr = col;
303
+ }
304
+ }
305
+
306
+ template <typename scalar_t, unsigned int blockSize>
307
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
308
+ const scalar_t *grad_col,
309
+ const scalar_t *data_value,
310
+ const int64_t *data_spatial_shapes,
311
+ const int64_t *data_level_start_index,
312
+ const scalar_t *data_sampling_loc,
313
+ const scalar_t *data_attn_weight,
314
+ const int batch_size,
315
+ const int spatial_size,
316
+ const int num_heads,
317
+ const int channels,
318
+ const int num_levels,
319
+ const int num_query,
320
+ const int num_point,
321
+ scalar_t *grad_value,
322
+ scalar_t *grad_sampling_loc,
323
+ scalar_t *grad_attn_weight)
324
+ {
325
+ CUDA_KERNEL_LOOP(index, n)
326
+ {
327
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
328
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
329
+ unsigned int tid = threadIdx.x;
330
+ int _temp = index;
331
+ const int c_col = _temp % channels;
332
+ _temp /= channels;
333
+ const int sampling_index = _temp;
334
+ const int m_col = _temp % num_heads;
335
+ _temp /= num_heads;
336
+ const int q_col = _temp % num_query;
337
+ _temp /= num_query;
338
+ const int b_col = _temp;
339
+
340
+ const scalar_t top_grad = grad_col[index];
341
+
342
+ int data_weight_ptr = sampling_index * num_levels * num_point;
343
+ int data_loc_w_ptr = data_weight_ptr << 1;
344
+ const int grad_sampling_ptr = data_weight_ptr;
345
+ grad_sampling_loc += grad_sampling_ptr << 1;
346
+ grad_attn_weight += grad_sampling_ptr;
347
+ const int grad_weight_stride = 1;
348
+ const int grad_loc_stride = 2;
349
+ const int qid_stride = num_heads * channels;
350
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
351
+
352
+ for (int l_col=0; l_col < num_levels; ++l_col)
353
+ {
354
+ const int level_start_id = data_level_start_index[l_col];
355
+ const int spatial_h_ptr = l_col << 1;
356
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
357
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
358
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
359
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
360
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
361
+
362
+ for (int p_col=0; p_col < num_point; ++p_col)
363
+ {
364
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
365
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
366
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
367
+
368
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
369
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
370
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
371
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
372
+ *(cache_grad_attn_weight+threadIdx.x)=0;
373
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
374
+ {
375
+ ms_deform_attn_col2im_bilinear(
376
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
377
+ top_grad, weight, grad_value_ptr,
378
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
379
+ }
380
+
381
+ __syncthreads();
382
+ if (tid == 0)
383
+ {
384
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
385
+ int sid=2;
386
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
387
+ {
388
+ _grad_w += cache_grad_sampling_loc[sid];
389
+ _grad_h += cache_grad_sampling_loc[sid + 1];
390
+ _grad_a += cache_grad_attn_weight[tid];
391
+ sid += 2;
392
+ }
393
+
394
+
395
+ *grad_sampling_loc = _grad_w;
396
+ *(grad_sampling_loc + 1) = _grad_h;
397
+ *grad_attn_weight = _grad_a;
398
+ }
399
+ __syncthreads();
400
+
401
+ data_weight_ptr += 1;
402
+ data_loc_w_ptr += 2;
403
+ grad_attn_weight += grad_weight_stride;
404
+ grad_sampling_loc += grad_loc_stride;
405
+ }
406
+ }
407
+ }
408
+ }
409
+
410
+
411
+ template <typename scalar_t, unsigned int blockSize>
412
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
413
+ const scalar_t *grad_col,
414
+ const scalar_t *data_value,
415
+ const int64_t *data_spatial_shapes,
416
+ const int64_t *data_level_start_index,
417
+ const scalar_t *data_sampling_loc,
418
+ const scalar_t *data_attn_weight,
419
+ const int batch_size,
420
+ const int spatial_size,
421
+ const int num_heads,
422
+ const int channels,
423
+ const int num_levels,
424
+ const int num_query,
425
+ const int num_point,
426
+ scalar_t *grad_value,
427
+ scalar_t *grad_sampling_loc,
428
+ scalar_t *grad_attn_weight)
429
+ {
430
+ CUDA_KERNEL_LOOP(index, n)
431
+ {
432
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
433
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
434
+ unsigned int tid = threadIdx.x;
435
+ int _temp = index;
436
+ const int c_col = _temp % channels;
437
+ _temp /= channels;
438
+ const int sampling_index = _temp;
439
+ const int m_col = _temp % num_heads;
440
+ _temp /= num_heads;
441
+ const int q_col = _temp % num_query;
442
+ _temp /= num_query;
443
+ const int b_col = _temp;
444
+
445
+ const scalar_t top_grad = grad_col[index];
446
+
447
+ int data_weight_ptr = sampling_index * num_levels * num_point;
448
+ int data_loc_w_ptr = data_weight_ptr << 1;
449
+ const int grad_sampling_ptr = data_weight_ptr;
450
+ grad_sampling_loc += grad_sampling_ptr << 1;
451
+ grad_attn_weight += grad_sampling_ptr;
452
+ const int grad_weight_stride = 1;
453
+ const int grad_loc_stride = 2;
454
+ const int qid_stride = num_heads * channels;
455
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
456
+
457
+ for (int l_col=0; l_col < num_levels; ++l_col)
458
+ {
459
+ const int level_start_id = data_level_start_index[l_col];
460
+ const int spatial_h_ptr = l_col << 1;
461
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
462
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
463
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
464
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
465
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
466
+
467
+ for (int p_col=0; p_col < num_point; ++p_col)
468
+ {
469
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
470
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
471
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
472
+
473
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
474
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
475
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
476
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
477
+ *(cache_grad_attn_weight+threadIdx.x)=0;
478
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
479
+ {
480
+ ms_deform_attn_col2im_bilinear(
481
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
482
+ top_grad, weight, grad_value_ptr,
483
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
484
+ }
485
+
486
+ __syncthreads();
487
+
488
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
489
+ {
490
+ if (tid < s) {
491
+ const unsigned int xid1 = tid << 1;
492
+ const unsigned int xid2 = (tid + s) << 1;
493
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
494
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
495
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
496
+ }
497
+ __syncthreads();
498
+ }
499
+
500
+ if (tid == 0)
501
+ {
502
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
503
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
504
+ *grad_attn_weight = cache_grad_attn_weight[0];
505
+ }
506
+ __syncthreads();
507
+
508
+ data_weight_ptr += 1;
509
+ data_loc_w_ptr += 2;
510
+ grad_attn_weight += grad_weight_stride;
511
+ grad_sampling_loc += grad_loc_stride;
512
+ }
513
+ }
514
+ }
515
+ }
516
+
517
+
518
+ template <typename scalar_t>
519
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
520
+ const scalar_t *grad_col,
521
+ const scalar_t *data_value,
522
+ const int64_t *data_spatial_shapes,
523
+ const int64_t *data_level_start_index,
524
+ const scalar_t *data_sampling_loc,
525
+ const scalar_t *data_attn_weight,
526
+ const int batch_size,
527
+ const int spatial_size,
528
+ const int num_heads,
529
+ const int channels,
530
+ const int num_levels,
531
+ const int num_query,
532
+ const int num_point,
533
+ scalar_t *grad_value,
534
+ scalar_t *grad_sampling_loc,
535
+ scalar_t *grad_attn_weight)
536
+ {
537
+ CUDA_KERNEL_LOOP(index, n)
538
+ {
539
+ extern __shared__ int _s[];
540
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
541
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
542
+ unsigned int tid = threadIdx.x;
543
+ int _temp = index;
544
+ const int c_col = _temp % channels;
545
+ _temp /= channels;
546
+ const int sampling_index = _temp;
547
+ const int m_col = _temp % num_heads;
548
+ _temp /= num_heads;
549
+ const int q_col = _temp % num_query;
550
+ _temp /= num_query;
551
+ const int b_col = _temp;
552
+
553
+ const scalar_t top_grad = grad_col[index];
554
+
555
+ int data_weight_ptr = sampling_index * num_levels * num_point;
556
+ int data_loc_w_ptr = data_weight_ptr << 1;
557
+ const int grad_sampling_ptr = data_weight_ptr;
558
+ grad_sampling_loc += grad_sampling_ptr << 1;
559
+ grad_attn_weight += grad_sampling_ptr;
560
+ const int grad_weight_stride = 1;
561
+ const int grad_loc_stride = 2;
562
+ const int qid_stride = num_heads * channels;
563
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
564
+
565
+ for (int l_col=0; l_col < num_levels; ++l_col)
566
+ {
567
+ const int level_start_id = data_level_start_index[l_col];
568
+ const int spatial_h_ptr = l_col << 1;
569
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
570
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
571
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
572
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
573
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
574
+
575
+ for (int p_col=0; p_col < num_point; ++p_col)
576
+ {
577
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
578
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
579
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
580
+
581
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
582
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
583
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
584
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
585
+ *(cache_grad_attn_weight+threadIdx.x)=0;
586
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
587
+ {
588
+ ms_deform_attn_col2im_bilinear(
589
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
590
+ top_grad, weight, grad_value_ptr,
591
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
592
+ }
593
+
594
+ __syncthreads();
595
+ if (tid == 0)
596
+ {
597
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
598
+ int sid=2;
599
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
600
+ {
601
+ _grad_w += cache_grad_sampling_loc[sid];
602
+ _grad_h += cache_grad_sampling_loc[sid + 1];
603
+ _grad_a += cache_grad_attn_weight[tid];
604
+ sid += 2;
605
+ }
606
+
607
+
608
+ *grad_sampling_loc = _grad_w;
609
+ *(grad_sampling_loc + 1) = _grad_h;
610
+ *grad_attn_weight = _grad_a;
611
+ }
612
+ __syncthreads();
613
+
614
+ data_weight_ptr += 1;
615
+ data_loc_w_ptr += 2;
616
+ grad_attn_weight += grad_weight_stride;
617
+ grad_sampling_loc += grad_loc_stride;
618
+ }
619
+ }
620
+ }
621
+ }
622
+
623
+ template <typename scalar_t>
624
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
625
+ const scalar_t *grad_col,
626
+ const scalar_t *data_value,
627
+ const int64_t *data_spatial_shapes,
628
+ const int64_t *data_level_start_index,
629
+ const scalar_t *data_sampling_loc,
630
+ const scalar_t *data_attn_weight,
631
+ const int batch_size,
632
+ const int spatial_size,
633
+ const int num_heads,
634
+ const int channels,
635
+ const int num_levels,
636
+ const int num_query,
637
+ const int num_point,
638
+ scalar_t *grad_value,
639
+ scalar_t *grad_sampling_loc,
640
+ scalar_t *grad_attn_weight)
641
+ {
642
+ CUDA_KERNEL_LOOP(index, n)
643
+ {
644
+ extern __shared__ int _s[];
645
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
646
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
647
+ unsigned int tid = threadIdx.x;
648
+ int _temp = index;
649
+ const int c_col = _temp % channels;
650
+ _temp /= channels;
651
+ const int sampling_index = _temp;
652
+ const int m_col = _temp % num_heads;
653
+ _temp /= num_heads;
654
+ const int q_col = _temp % num_query;
655
+ _temp /= num_query;
656
+ const int b_col = _temp;
657
+
658
+ const scalar_t top_grad = grad_col[index];
659
+
660
+ int data_weight_ptr = sampling_index * num_levels * num_point;
661
+ int data_loc_w_ptr = data_weight_ptr << 1;
662
+ const int grad_sampling_ptr = data_weight_ptr;
663
+ grad_sampling_loc += grad_sampling_ptr << 1;
664
+ grad_attn_weight += grad_sampling_ptr;
665
+ const int grad_weight_stride = 1;
666
+ const int grad_loc_stride = 2;
667
+ const int qid_stride = num_heads * channels;
668
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
669
+
670
+ for (int l_col=0; l_col < num_levels; ++l_col)
671
+ {
672
+ const int level_start_id = data_level_start_index[l_col];
673
+ const int spatial_h_ptr = l_col << 1;
674
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
675
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
676
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
677
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
678
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
679
+
680
+ for (int p_col=0; p_col < num_point; ++p_col)
681
+ {
682
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
683
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
684
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
685
+
686
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
687
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
688
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
689
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
690
+ *(cache_grad_attn_weight+threadIdx.x)=0;
691
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
692
+ {
693
+ ms_deform_attn_col2im_bilinear(
694
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
695
+ top_grad, weight, grad_value_ptr,
696
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
697
+ }
698
+
699
+ __syncthreads();
700
+
701
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
702
+ {
703
+ if (tid < s) {
704
+ const unsigned int xid1 = tid << 1;
705
+ const unsigned int xid2 = (tid + s) << 1;
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
709
+ if (tid + (s << 1) < spre)
710
+ {
711
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
712
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
713
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
714
+ }
715
+ }
716
+ __syncthreads();
717
+ }
718
+
719
+ if (tid == 0)
720
+ {
721
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
722
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
723
+ *grad_attn_weight = cache_grad_attn_weight[0];
724
+ }
725
+ __syncthreads();
726
+
727
+ data_weight_ptr += 1;
728
+ data_loc_w_ptr += 2;
729
+ grad_attn_weight += grad_weight_stride;
730
+ grad_sampling_loc += grad_loc_stride;
731
+ }
732
+ }
733
+ }
734
+ }
735
+
736
+ template <typename scalar_t>
737
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
738
+ const scalar_t *grad_col,
739
+ const scalar_t *data_value,
740
+ const int64_t *data_spatial_shapes,
741
+ const int64_t *data_level_start_index,
742
+ const scalar_t *data_sampling_loc,
743
+ const scalar_t *data_attn_weight,
744
+ const int batch_size,
745
+ const int spatial_size,
746
+ const int num_heads,
747
+ const int channels,
748
+ const int num_levels,
749
+ const int num_query,
750
+ const int num_point,
751
+ scalar_t *grad_value,
752
+ scalar_t *grad_sampling_loc,
753
+ scalar_t *grad_attn_weight)
754
+ {
755
+ CUDA_KERNEL_LOOP(index, n)
756
+ {
757
+ extern __shared__ int _s[];
758
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
759
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
760
+ unsigned int tid = threadIdx.x;
761
+ int _temp = index;
762
+ const int c_col = _temp % channels;
763
+ _temp /= channels;
764
+ const int sampling_index = _temp;
765
+ const int m_col = _temp % num_heads;
766
+ _temp /= num_heads;
767
+ const int q_col = _temp % num_query;
768
+ _temp /= num_query;
769
+ const int b_col = _temp;
770
+
771
+ const scalar_t top_grad = grad_col[index];
772
+
773
+ int data_weight_ptr = sampling_index * num_levels * num_point;
774
+ int data_loc_w_ptr = data_weight_ptr << 1;
775
+ const int grad_sampling_ptr = data_weight_ptr;
776
+ grad_sampling_loc += grad_sampling_ptr << 1;
777
+ grad_attn_weight += grad_sampling_ptr;
778
+ const int grad_weight_stride = 1;
779
+ const int grad_loc_stride = 2;
780
+ const int qid_stride = num_heads * channels;
781
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
782
+
783
+ for (int l_col=0; l_col < num_levels; ++l_col)
784
+ {
785
+ const int level_start_id = data_level_start_index[l_col];
786
+ const int spatial_h_ptr = l_col << 1;
787
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
788
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
789
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
790
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
791
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
792
+
793
+ for (int p_col=0; p_col < num_point; ++p_col)
794
+ {
795
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
796
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
797
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
798
+
799
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
800
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
801
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
802
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
803
+ *(cache_grad_attn_weight+threadIdx.x)=0;
804
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
805
+ {
806
+ ms_deform_attn_col2im_bilinear(
807
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
808
+ top_grad, weight, grad_value_ptr,
809
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
810
+ }
811
+
812
+ __syncthreads();
813
+
814
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
815
+ {
816
+ if (tid < s) {
817
+ const unsigned int xid1 = tid << 1;
818
+ const unsigned int xid2 = (tid + s) << 1;
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
822
+ if (tid + (s << 1) < spre)
823
+ {
824
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
825
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
826
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
827
+ }
828
+ }
829
+ __syncthreads();
830
+ }
831
+
832
+ if (tid == 0)
833
+ {
834
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
835
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
836
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
837
+ }
838
+ __syncthreads();
839
+
840
+ data_weight_ptr += 1;
841
+ data_loc_w_ptr += 2;
842
+ grad_attn_weight += grad_weight_stride;
843
+ grad_sampling_loc += grad_loc_stride;
844
+ }
845
+ }
846
+ }
847
+ }
848
+
849
+
850
+ template <typename scalar_t>
851
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
852
+ const scalar_t *grad_col,
853
+ const scalar_t *data_value,
854
+ const int64_t *data_spatial_shapes,
855
+ const int64_t *data_level_start_index,
856
+ const scalar_t *data_sampling_loc,
857
+ const scalar_t *data_attn_weight,
858
+ const int batch_size,
859
+ const int spatial_size,
860
+ const int num_heads,
861
+ const int channels,
862
+ const int num_levels,
863
+ const int num_query,
864
+ const int num_point,
865
+ scalar_t *grad_value,
866
+ scalar_t *grad_sampling_loc,
867
+ scalar_t *grad_attn_weight)
868
+ {
869
+ CUDA_KERNEL_LOOP(index, n)
870
+ {
871
+ int _temp = index;
872
+ const int c_col = _temp % channels;
873
+ _temp /= channels;
874
+ const int sampling_index = _temp;
875
+ const int m_col = _temp % num_heads;
876
+ _temp /= num_heads;
877
+ const int q_col = _temp % num_query;
878
+ _temp /= num_query;
879
+ const int b_col = _temp;
880
+
881
+ const scalar_t top_grad = grad_col[index];
882
+
883
+ int data_weight_ptr = sampling_index * num_levels * num_point;
884
+ int data_loc_w_ptr = data_weight_ptr << 1;
885
+ const int grad_sampling_ptr = data_weight_ptr;
886
+ grad_sampling_loc += grad_sampling_ptr << 1;
887
+ grad_attn_weight += grad_sampling_ptr;
888
+ const int grad_weight_stride = 1;
889
+ const int grad_loc_stride = 2;
890
+ const int qid_stride = num_heads * channels;
891
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
892
+
893
+ for (int l_col=0; l_col < num_levels; ++l_col)
894
+ {
895
+ const int level_start_id = data_level_start_index[l_col];
896
+ const int spatial_h_ptr = l_col << 1;
897
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
898
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
899
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
900
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
901
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
902
+
903
+ for (int p_col=0; p_col < num_point; ++p_col)
904
+ {
905
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
906
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
907
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
908
+
909
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
910
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
911
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
912
+ {
913
+ ms_deform_attn_col2im_bilinear_gm(
914
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
915
+ top_grad, weight, grad_value_ptr,
916
+ grad_sampling_loc, grad_attn_weight);
917
+ }
918
+ data_weight_ptr += 1;
919
+ data_loc_w_ptr += 2;
920
+ grad_attn_weight += grad_weight_stride;
921
+ grad_sampling_loc += grad_loc_stride;
922
+ }
923
+ }
924
+ }
925
+ }
926
+
927
+
928
+ template <typename scalar_t>
929
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
930
+ const scalar_t* data_value,
931
+ const int64_t* data_spatial_shapes,
932
+ const int64_t* data_level_start_index,
933
+ const scalar_t* data_sampling_loc,
934
+ const scalar_t* data_attn_weight,
935
+ const int batch_size,
936
+ const int spatial_size,
937
+ const int num_heads,
938
+ const int channels,
939
+ const int num_levels,
940
+ const int num_query,
941
+ const int num_point,
942
+ scalar_t* data_col)
943
+ {
944
+ const int num_kernels = batch_size * num_query * num_heads * channels;
945
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
946
+ const int num_threads = CUDA_NUM_THREADS;
947
+ ms_deformable_im2col_gpu_kernel<scalar_t>
948
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
949
+ 0, stream>>>(
950
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
951
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
952
+
953
+ cudaError_t err = cudaGetLastError();
954
+ if (err != cudaSuccess)
955
+ {
956
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
957
+ }
958
+
959
+ }
960
+
961
+ template <typename scalar_t>
962
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
963
+ const scalar_t* grad_col,
964
+ const scalar_t* data_value,
965
+ const int64_t * data_spatial_shapes,
966
+ const int64_t * data_level_start_index,
967
+ const scalar_t * data_sampling_loc,
968
+ const scalar_t * data_attn_weight,
969
+ const int batch_size,
970
+ const int spatial_size,
971
+ const int num_heads,
972
+ const int channels,
973
+ const int num_levels,
974
+ const int num_query,
975
+ const int num_point,
976
+ scalar_t* grad_value,
977
+ scalar_t* grad_sampling_loc,
978
+ scalar_t* grad_attn_weight)
979
+ {
980
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
981
+ const int num_kernels = batch_size * num_query * num_heads * channels;
982
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
983
+ if (channels > 1024)
984
+ {
985
+ if ((channels & 1023) == 0)
986
+ {
987
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
988
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
989
+ num_threads*3*sizeof(scalar_t), stream>>>(
990
+ num_kernels,
991
+ grad_col,
992
+ data_value,
993
+ data_spatial_shapes,
994
+ data_level_start_index,
995
+ data_sampling_loc,
996
+ data_attn_weight,
997
+ batch_size,
998
+ spatial_size,
999
+ num_heads,
1000
+ channels,
1001
+ num_levels,
1002
+ num_query,
1003
+ num_point,
1004
+ grad_value,
1005
+ grad_sampling_loc,
1006
+ grad_attn_weight);
1007
+ }
1008
+ else
1009
+ {
1010
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1011
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1012
+ 0, stream>>>(
1013
+ num_kernels,
1014
+ grad_col,
1015
+ data_value,
1016
+ data_spatial_shapes,
1017
+ data_level_start_index,
1018
+ data_sampling_loc,
1019
+ data_attn_weight,
1020
+ batch_size,
1021
+ spatial_size,
1022
+ num_heads,
1023
+ channels,
1024
+ num_levels,
1025
+ num_query,
1026
+ num_point,
1027
+ grad_value,
1028
+ grad_sampling_loc,
1029
+ grad_attn_weight);
1030
+ }
1031
+ }
1032
+ else{
1033
+ switch(channels)
1034
+ {
1035
+ case 1:
1036
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1037
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1038
+ 0, stream>>>(
1039
+ num_kernels,
1040
+ grad_col,
1041
+ data_value,
1042
+ data_spatial_shapes,
1043
+ data_level_start_index,
1044
+ data_sampling_loc,
1045
+ data_attn_weight,
1046
+ batch_size,
1047
+ spatial_size,
1048
+ num_heads,
1049
+ channels,
1050
+ num_levels,
1051
+ num_query,
1052
+ num_point,
1053
+ grad_value,
1054
+ grad_sampling_loc,
1055
+ grad_attn_weight);
1056
+ break;
1057
+ case 2:
1058
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1059
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1060
+ 0, stream>>>(
1061
+ num_kernels,
1062
+ grad_col,
1063
+ data_value,
1064
+ data_spatial_shapes,
1065
+ data_level_start_index,
1066
+ data_sampling_loc,
1067
+ data_attn_weight,
1068
+ batch_size,
1069
+ spatial_size,
1070
+ num_heads,
1071
+ channels,
1072
+ num_levels,
1073
+ num_query,
1074
+ num_point,
1075
+ grad_value,
1076
+ grad_sampling_loc,
1077
+ grad_attn_weight);
1078
+ break;
1079
+ case 4:
1080
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1081
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1082
+ 0, stream>>>(
1083
+ num_kernels,
1084
+ grad_col,
1085
+ data_value,
1086
+ data_spatial_shapes,
1087
+ data_level_start_index,
1088
+ data_sampling_loc,
1089
+ data_attn_weight,
1090
+ batch_size,
1091
+ spatial_size,
1092
+ num_heads,
1093
+ channels,
1094
+ num_levels,
1095
+ num_query,
1096
+ num_point,
1097
+ grad_value,
1098
+ grad_sampling_loc,
1099
+ grad_attn_weight);
1100
+ break;
1101
+ case 8:
1102
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1103
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1104
+ 0, stream>>>(
1105
+ num_kernels,
1106
+ grad_col,
1107
+ data_value,
1108
+ data_spatial_shapes,
1109
+ data_level_start_index,
1110
+ data_sampling_loc,
1111
+ data_attn_weight,
1112
+ batch_size,
1113
+ spatial_size,
1114
+ num_heads,
1115
+ channels,
1116
+ num_levels,
1117
+ num_query,
1118
+ num_point,
1119
+ grad_value,
1120
+ grad_sampling_loc,
1121
+ grad_attn_weight);
1122
+ break;
1123
+ case 16:
1124
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1125
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1126
+ 0, stream>>>(
1127
+ num_kernels,
1128
+ grad_col,
1129
+ data_value,
1130
+ data_spatial_shapes,
1131
+ data_level_start_index,
1132
+ data_sampling_loc,
1133
+ data_attn_weight,
1134
+ batch_size,
1135
+ spatial_size,
1136
+ num_heads,
1137
+ channels,
1138
+ num_levels,
1139
+ num_query,
1140
+ num_point,
1141
+ grad_value,
1142
+ grad_sampling_loc,
1143
+ grad_attn_weight);
1144
+ break;
1145
+ case 32:
1146
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1147
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1148
+ 0, stream>>>(
1149
+ num_kernels,
1150
+ grad_col,
1151
+ data_value,
1152
+ data_spatial_shapes,
1153
+ data_level_start_index,
1154
+ data_sampling_loc,
1155
+ data_attn_weight,
1156
+ batch_size,
1157
+ spatial_size,
1158
+ num_heads,
1159
+ channels,
1160
+ num_levels,
1161
+ num_query,
1162
+ num_point,
1163
+ grad_value,
1164
+ grad_sampling_loc,
1165
+ grad_attn_weight);
1166
+ break;
1167
+ case 64:
1168
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1169
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1170
+ 0, stream>>>(
1171
+ num_kernels,
1172
+ grad_col,
1173
+ data_value,
1174
+ data_spatial_shapes,
1175
+ data_level_start_index,
1176
+ data_sampling_loc,
1177
+ data_attn_weight,
1178
+ batch_size,
1179
+ spatial_size,
1180
+ num_heads,
1181
+ channels,
1182
+ num_levels,
1183
+ num_query,
1184
+ num_point,
1185
+ grad_value,
1186
+ grad_sampling_loc,
1187
+ grad_attn_weight);
1188
+ break;
1189
+ case 128:
1190
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1191
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1192
+ 0, stream>>>(
1193
+ num_kernels,
1194
+ grad_col,
1195
+ data_value,
1196
+ data_spatial_shapes,
1197
+ data_level_start_index,
1198
+ data_sampling_loc,
1199
+ data_attn_weight,
1200
+ batch_size,
1201
+ spatial_size,
1202
+ num_heads,
1203
+ channels,
1204
+ num_levels,
1205
+ num_query,
1206
+ num_point,
1207
+ grad_value,
1208
+ grad_sampling_loc,
1209
+ grad_attn_weight);
1210
+ break;
1211
+ case 256:
1212
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1213
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1214
+ 0, stream>>>(
1215
+ num_kernels,
1216
+ grad_col,
1217
+ data_value,
1218
+ data_spatial_shapes,
1219
+ data_level_start_index,
1220
+ data_sampling_loc,
1221
+ data_attn_weight,
1222
+ batch_size,
1223
+ spatial_size,
1224
+ num_heads,
1225
+ channels,
1226
+ num_levels,
1227
+ num_query,
1228
+ num_point,
1229
+ grad_value,
1230
+ grad_sampling_loc,
1231
+ grad_attn_weight);
1232
+ break;
1233
+ case 512:
1234
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1235
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1236
+ 0, stream>>>(
1237
+ num_kernels,
1238
+ grad_col,
1239
+ data_value,
1240
+ data_spatial_shapes,
1241
+ data_level_start_index,
1242
+ data_sampling_loc,
1243
+ data_attn_weight,
1244
+ batch_size,
1245
+ spatial_size,
1246
+ num_heads,
1247
+ channels,
1248
+ num_levels,
1249
+ num_query,
1250
+ num_point,
1251
+ grad_value,
1252
+ grad_sampling_loc,
1253
+ grad_attn_weight);
1254
+ break;
1255
+ case 1024:
1256
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1257
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1258
+ 0, stream>>>(
1259
+ num_kernels,
1260
+ grad_col,
1261
+ data_value,
1262
+ data_spatial_shapes,
1263
+ data_level_start_index,
1264
+ data_sampling_loc,
1265
+ data_attn_weight,
1266
+ batch_size,
1267
+ spatial_size,
1268
+ num_heads,
1269
+ channels,
1270
+ num_levels,
1271
+ num_query,
1272
+ num_point,
1273
+ grad_value,
1274
+ grad_sampling_loc,
1275
+ grad_attn_weight);
1276
+ break;
1277
+ default:
1278
+ if (channels < 64)
1279
+ {
1280
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1281
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1282
+ num_threads*3*sizeof(scalar_t), stream>>>(
1283
+ num_kernels,
1284
+ grad_col,
1285
+ data_value,
1286
+ data_spatial_shapes,
1287
+ data_level_start_index,
1288
+ data_sampling_loc,
1289
+ data_attn_weight,
1290
+ batch_size,
1291
+ spatial_size,
1292
+ num_heads,
1293
+ channels,
1294
+ num_levels,
1295
+ num_query,
1296
+ num_point,
1297
+ grad_value,
1298
+ grad_sampling_loc,
1299
+ grad_attn_weight);
1300
+ }
1301
+ else
1302
+ {
1303
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ num_threads*3*sizeof(scalar_t), stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ }
1324
+ }
1325
+ }
1326
+ cudaError_t err = cudaGetLastError();
1327
+ if (err != cudaSuccess)
1328
+ {
1329
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1330
+ }
1331
+
1332
+ }
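The kernels above compute, for every query and head, a weighted sum of bilinearly sampled values over P points on each of L feature levels, with sampling positions taken as loc * size - 0.5. A rough pure-PyTorch equivalent built on F.grid_sample (align_corners=False reproduces the same loc * size - 0.5 convention) is sketched below; it mirrors the ms_deform_attn_core_pytorch reference that test.py imports, but the function and variable names here are only illustrative:

import torch
import torch.nn.functional as F

def ms_deform_attn_reference(value, spatial_shapes, sampling_locations, attention_weights):
    # value: (N, S, M, D) with S = sum of H*W over levels
    # sampling_locations: (N, Lq, M, L, P, 2) in [0, 1]; attention_weights: (N, Lq, M, L, P)
    N, S, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    value_list = value.split([int(H * W) for H, W in spatial_shapes], dim=1)
    grids = 2 * sampling_locations - 1                    # grid_sample expects [-1, 1]
    sampled_per_level = []
    for lid, (H, W) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        value_l = value_list[lid].flatten(2).transpose(1, 2).reshape(N * M, D, int(H), int(W))
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = grids[:, :, :, lid].transpose(1, 2).flatten(0, 1)
        sampled_per_level.append(
            F.grid_sample(value_l, grid_l, mode='bilinear',
                          padding_mode='zeros', align_corners=False))  # (N*M, D, Lq, P)
    # (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    attn = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    out = (torch.stack(sampled_per_level, dim=-2).flatten(-2) * attn).sum(-1)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()         # (N, Lq, M*D)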
imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,67 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #pragma once
17
+
18
+ #include "cpu/ms_deform_attn_cpu.h"
19
+
20
+ #ifdef WITH_CUDA
21
+ #include "cuda/ms_deform_attn_cuda.h"
22
+ #endif
23
+
24
+
25
+ at::Tensor
26
+ ms_deform_attn_forward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const int im2col_step)
33
+ {
34
+ if (value.type().is_cuda())
35
+ {
36
+ #ifdef WITH_CUDA
37
+ return ms_deform_attn_cuda_forward(
38
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
39
+ #else
40
+ AT_ERROR("Not compiled with GPU support");
41
+ #endif
42
+ }
43
+ AT_ERROR("Not implemented on the CPU");
44
+ }
45
+
46
+ std::vector<at::Tensor>
47
+ ms_deform_attn_backward(
48
+ const at::Tensor &value,
49
+ const at::Tensor &spatial_shapes,
50
+ const at::Tensor &level_start_index,
51
+ const at::Tensor &sampling_loc,
52
+ const at::Tensor &attn_weight,
53
+ const at::Tensor &grad_output,
54
+ const int im2col_step)
55
+ {
56
+ if (value.type().is_cuda())
57
+ {
58
+ #ifdef WITH_CUDA
59
+ return ms_deform_attn_cuda_backward(
60
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
61
+ #else
62
+ AT_ERROR("Not compiled with GPU support");
63
+ #endif
64
+ }
65
+ AT_ERROR("Not implemented on the CPU");
66
+ }
67
+
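On the Python side, these two dispatch functions are typically wrapped in a torch.autograd.Function so that gradients flow back to value, sampling_loc, and attn_weight; in this repo that role is played by the MSDeformAttnFunction that test.py imports from functions.ms_deform_attn_func. The sketch below only illustrates the pattern, and the name of the compiled extension module is an assumption:

import torch
import MultiScaleDeformableAttention as MSDA   # assumed name of the compiled extension

class MSDeformAttnFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, ctx.im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_loc, attn_weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
            grad_output.contiguous(), ctx.im2col_step)
        # no gradients for the shape/index tensors or the integer im2col_step
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None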
imcui/third_party/rdd/RDD/models/ops/src/vision.cpp ADDED
@@ -0,0 +1,21 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ /*!
12
+ * Copyright (c) Facebook, Inc. and its affiliates.
13
+ * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
14
+ */
15
+
16
+ #include "ms_deform_attn.h"
17
+
18
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
20
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
21
+ }
imcui/third_party/rdd/RDD/models/ops/test.py ADDED
@@ -0,0 +1,92 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ # Copyright (c) Facebook, Inc. and its affiliates.
10
+ # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import print_function
14
+ from __future__ import division
15
+
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.autograd import gradcheck
20
+
21
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
22
+
23
+
24
+ N, M, D = 1, 2, 2
25
+ Lq, L, P = 2, 2, 2
26
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
27
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
28
+ S = sum([(H*W).item() for H, W in shapes])
29
+
30
+
31
+ torch.manual_seed(3)
32
+
33
+
34
+ @torch.no_grad()
35
+ def check_forward_equal_with_pytorch_double():
36
+ value = torch.rand(N, S, M, D).cuda() * 0.01
37
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
38
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
39
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
40
+ im2col_step = 2
41
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
42
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
43
+ fwdok = torch.allclose(output_cuda, output_pytorch)
44
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
45
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
46
+
47
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
48
+
49
+
50
+ @torch.no_grad()
51
+ def check_forward_equal_with_pytorch_float():
52
+ value = torch.rand(N, S, M, D).cuda() * 0.01
53
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
54
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
55
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
56
+ im2col_step = 2
57
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
58
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
59
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
60
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
61
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
62
+
63
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
64
+
65
+
66
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
67
+
68
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
69
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
70
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
71
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
72
+ im2col_step = 2
73
+ func = MSDeformAttnFunction.apply
74
+
75
+ value.requires_grad = grad_value
76
+ sampling_locations.requires_grad = grad_sampling_loc
77
+ attention_weights.requires_grad = grad_attn_weight
78
+
79
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
80
+
81
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
82
+
83
+
84
+ if __name__ == '__main__':
85
+ check_forward_equal_with_pytorch_double()
86
+ check_forward_equal_with_pytorch_float()
87
+
88
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
89
+ check_gradient_numerical(channels, True, True, True)
90
+
91
+
92
+
imcui/third_party/rdd/RDD/models/position_encoding.py ADDED
@@ -0,0 +1,48 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from ..utils.misc import NestedTensor
5
+
6
+ class PositionEmbeddingSine(nn.Module):
7
+ """
8
+ This is a more standard version of the position embedding, very similar to the one
9
+ used by the Attention is all you need paper, generalized to work on images.
10
+ """
11
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
12
+ super().__init__()
13
+ self.num_pos_feats = num_pos_feats
14
+ self.temperature = temperature
15
+ self.normalize = normalize
16
+ if scale is not None and normalize is False:
17
+ raise ValueError("normalize should be True if scale is passed")
18
+ if scale is None:
19
+ scale = 2 * math.pi
20
+ self.scale = scale
21
+
22
+ def forward(self, tensor_list: NestedTensor):
23
+ x = tensor_list.tensors
24
+ mask = tensor_list.mask
25
+ assert mask is not None
26
+ not_mask = ~mask
27
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
28
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
29
+ if self.normalize:
30
+ eps = 1e-6
31
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
32
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
33
+
34
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
35
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
36
+
37
+ pos_x = x_embed[:, :, :, None] / dim_t
38
+ pos_y = y_embed[:, :, :, None] / dim_t
39
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
40
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
41
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
42
+ return pos
43
+
44
+ def build_position_encoding(config):
45
+ N_steps = config['hidden_dim'] // 2
46
+ # TODO find a better way of exposing other arguments
47
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
48
+ return position_embedding
imcui/third_party/rdd/RDD/models/soft_detect.py ADDED
@@ -0,0 +1,250 @@
1
+ # ALIKE: https://github.com/Shiaoming/ALIKE
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+
8
+ # coordinates system
9
+ # ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ]
10
+ # | -----------------------------
11
+ # | | |
12
+ # | | |
13
+ # | | |
14
+ # | | image |
15
+ # | | |
16
+ # | | |
17
+ # | | |
18
+ # | |---------------------------|
19
+ # v
20
+ # [ y: range=-1.0~1.0; h: range=0~H ]
21
+
22
+ def simple_nms(scores, nms_radius: int):
23
+ """ Fast Non-maximum suppression to remove nearby points """
24
+ assert (nms_radius >= 0)
25
+
26
+ def max_pool(x):
27
+ return torch.nn.functional.max_pool2d(
28
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
29
+
30
+ zeros = torch.zeros_like(scores)
31
+ max_mask = scores == max_pool(scores)
32
+
33
+ for _ in range(2):
34
+ supp_mask = max_pool(max_mask.float()) > 0
35
+ supp_scores = torch.where(supp_mask, zeros, scores)
36
+ new_max_mask = supp_scores == max_pool(supp_scores)
37
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
38
+ return torch.where(max_mask, scores, zeros)
39
+
40
+
41
+ """
42
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
43
+ https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
44
+ """
45
+
46
+ import torch
47
+ import torch.nn as nn
48
+ import torch.nn.functional as F
49
+
50
+ class InterpolateSparse2d(nn.Module):
51
+ """ Efficiently interpolate tensor at given sparse 2D positions. """
52
+ def __init__(self, mode = 'bicubic', align_corners = False):
53
+ super().__init__()
54
+ self.mode = mode
55
+ self.align_corners = align_corners
56
+
57
+ def normgrid(self, x, H, W):
58
+ """ Normalize coords to [-1,1]. """
59
+ return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
60
+
61
+ def forward(self, x, pos, H, W):
62
+ """
63
+ Input
64
+ x: [B, C, H, W] feature tensor
65
+ pos: [B, N, 2] tensor of positions
66
+ H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
67
+
68
+ Returns
69
+ [B, N, C] sampled channels at 2d positions
70
+ """
71
+ grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
72
+ x = F.grid_sample(x, grid, mode = self.mode , align_corners = False)
73
+ return x.permute(0,2,3,1).squeeze(-2)
74
+
75
+
76
+ class SoftDetect(nn.Module):
77
+ def __init__(self, radius=2, top_k=0, scores_th=0.2, n_limit=20000):
78
+ """
79
+ Args:
80
+ radius: soft detection radius, kernel size is (2 * radius + 1)
81
+ top_k: top_k > 0: return top k keypoints
82
+ scores_th: top_k <= 0 threshold mode: scores_th > 0: return keypoints with scores>scores_th
83
+ else: return keypoints with scores > scores.mean()
84
+ n_limit: max number of keypoint in threshold mode
85
+ """
86
+ super().__init__()
87
+ self.radius = radius
88
+ self.top_k = top_k
89
+ self.scores_th = scores_th
90
+ self.n_limit = n_limit
91
+ self.kernel_size = 2 * self.radius + 1
92
+ self.temperature = 0.1 # tuned temperature
93
+ self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
94
+ self.sample_descriptor = InterpolateSparse2d('bicubic')
95
+ # local xy grid
96
+ x = torch.linspace(-self.radius, self.radius, self.kernel_size)
97
+ # (kernel_size*kernel_size) x 2 : (w,h)
98
+ self.hw_grid = torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]]
99
+
100
+ def detect_keypoints(self, scores_map, normalized_coordinates=True):
101
+ b, c, h, w = scores_map.shape
102
+ scores_nograd = scores_map.detach()
103
+
104
+ # nms_scores = simple_nms(scores_nograd, self.radius)
105
+ nms_scores = simple_nms(scores_nograd, 2)
106
+
107
+ # remove border
108
+ nms_scores[:, :, :self.radius + 1, :] = 0
109
+ nms_scores[:, :, :, :self.radius + 1] = 0
110
+ nms_scores[:, :, h - self.radius:, :] = 0
111
+ nms_scores[:, :, :, w - self.radius:] = 0
112
+
113
+ # detect keypoints without grad
114
+ if self.top_k > 0:
115
+ topk = torch.topk(nms_scores.view(b, -1), self.top_k)
116
+ indices_keypoints = topk.indices # B x top_k
117
+ else:
118
+ if self.scores_th > 0:
119
+ masks = nms_scores > self.scores_th
120
+ if masks.sum() == 0:
121
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
122
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
123
+ else:
124
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
125
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
126
+ masks = masks.reshape(b, -1)
127
+
128
+ indices_keypoints = [] # list, B x (any size)
129
+ scores_view = scores_nograd.reshape(b, -1)
130
+ for mask, scores in zip(masks, scores_view):
131
+ indices = mask.nonzero(as_tuple=False)[:, 0]
132
+ if len(indices) > self.n_limit:
133
+ kpts_sc = scores[indices]
134
+ sort_idx = kpts_sc.sort(descending=True)[1]
135
+ sel_idx = sort_idx[:self.n_limit]
136
+ indices = indices[sel_idx]
137
+ indices_keypoints.append(indices)
138
+
139
+ # detect soft keypoints with grad backpropagation
140
+ patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
141
+ self.hw_grid = self.hw_grid.to(patches) # to device
142
+ keypoints = []
143
+ scoredispersitys = []
144
+ kptscores = []
145
+ for b_idx in range(b):
146
+ patch = patches[b_idx].t() # (H*W) x (kernel**2)
147
+ indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M
148
+ patch_scores = patch[indices_kpt] # M x (kernel**2)
149
+
150
+ # max is detached to prevent undesired backprop loops in the graph
151
+ max_v = patch_scores.max(dim=1).values.detach()[:, None]
152
+ x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1]
153
+
154
+ # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
155
+ xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2
156
+
157
+ hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius,
158
+ dim=-1) ** 2
159
+ scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
160
+
161
+ # compute result keypoints
162
+ keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2
163
+ keypoints_xy = keypoints_xy_nms + xy_residual
164
+ if normalized_coordinates:
165
+ keypoints_xy = keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1)
166
+
167
+ kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), keypoints_xy.view(1, 1, -1, 2),
168
+ mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN
169
+
170
+ keypoints.append(keypoints_xy)
171
+ scoredispersitys.append(scoredispersity)
172
+ kptscores.append(kptscore)
173
+
174
+ return keypoints, scoredispersitys, kptscores
175
+
176
+ def forward(self, scores_map, normalized_coordinates=True):
177
+ """
178
+ :param scores_map: Bx1xHxW
179
+ :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
180
+ """
181
+ B, _, H, W = scores_map.shape
182
+
183
+ keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map,
184
+ normalized_coordinates)
185
+
186
+ # keypoints: B M 2
187
+ # scoredispersitys:
188
+ return keypoints, kptscores, scoredispersitys
189
+
190
+ import torch
191
+ import torch.nn as nn
192
+
193
+ class Detect(nn.Module):
194
+ def __init__(self, stride=4, top_k=0, scores_th=0, n_limit=20000):
195
+ super().__init__()
196
+ self.stride = stride
197
+ self.top_k = top_k
198
+ self.scores_th = scores_th
199
+ self.n_limit = n_limit
200
+
201
+ def forward(self, scores, coords, w, h):
202
+ """
203
+ scores: B x N x 1 (keypoint confidence scores)
204
+ coords: B x N x 2 (offsets within stride x stride window)
205
+ w, h: Image dimensions
206
+ """
207
+ b, n, _ = scores.shape
208
+ kpts_list = []
209
+ scores_list = []
210
+
211
+ for b_idx in range(b):
212
+ score = scores[b_idx].squeeze(-1) # Shape: (N,)
213
+ coord = coords[b_idx] # Shape: (N, 2)
214
+
215
+ # Apply score thresholding
216
+ if self.scores_th >= 0:
217
+ valid = score > self.scores_th
218
+ else:
219
+ valid = score > score.mean()
220
+
221
+ valid_indices = valid.nonzero(as_tuple=True)[0] # Get valid indices
222
+ if valid_indices.numel() == 0:
223
+ kpts_list.append(torch.empty((0, 2), device=scores.device))
224
+ scores_list.append(torch.empty((0,), device=scores.device))
225
+ continue
226
+
227
+ # Compute keypoint locations in original image space
228
+ i_ids = valid_indices # Indices where keypoints exist
229
+ kpts = torch.stack([i_ids % w, i_ids // w], dim=1).to(torch.float) * self.stride # Grid position
230
+ kpts += coord[i_ids] * self.stride # Apply offset
231
+
232
+ # Normalize keypoints to [-1, 1] range
233
+ kpts = (kpts / torch.tensor([w - 1, h - 1], device=kpts.device, dtype=kpts.dtype)) * 2 - 1
234
+
235
+ # Filter top-k keypoints if needed
236
+ scores_valid = score[valid_indices]
237
+ if self.top_k > 0 and len(kpts) > self.top_k:
238
+ topk = torch.topk(scores_valid, self.top_k, dim=0)
239
+ kpts = kpts[topk.indices]
240
+ scores_valid = topk.values
241
+ elif self.top_k < 0:
242
+ if len(kpts) > self.n_limit:
243
+ sorted_idx = scores_valid.argsort(descending=True)[:self.n_limit]
244
+ kpts = kpts[sorted_idx]
245
+ scores_valid = scores_valid[sorted_idx]
246
+
247
+ kpts_list.append(kpts)
248
+ scores_list.append(scores_valid)
249
+
250
+ return kpts_list, scores_list
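The sub-pixel refinement in `SoftDetect.detect_keypoints` above is a temperature-scaled soft-argmax over each local score patch. The following standalone sketch, with toy values that are not part of the committed file, isolates just that step:

```python
import torch

# Soft-argmax as used in SoftDetect: the sub-pixel offset of each keypoint is the
# softmax-weighted average of the local (x, y) grid, with temperature T = 0.1.
radius, T = 2, 0.1
k = 2 * radius + 1
lin = torch.linspace(-radius, radius, k)
hw_grid = torch.stack(torch.meshgrid([lin, lin], indexing='ij')).view(2, -1).t()[:, [1, 0]]

patch_scores = torch.rand(3, k * k)                    # 3 candidate keypoints, one k x k patch each
max_v = patch_scores.max(dim=1, keepdim=True).values   # the real code detaches this max
x_exp = ((patch_scores - max_v) / T).exp()
xy_residual = x_exp @ hw_grid / x_exp.sum(dim=1, keepdim=True)  # per-keypoint (dx, dy) offset
print(xy_residual.shape)                               # torch.Size([3, 2]), values within [-radius, radius]
```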
imcui/third_party/rdd/RDD/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .misc import *
imcui/third_party/rdd/RDD/utils/misc.py ADDED
@@ -0,0 +1,531 @@
1
+ # Modified from Deformable DETR
2
+ # https://github.com/fundamentalvision/Deformable-DETR
3
+ # ------------------------------------------------------------------------
4
+ # Deformable DETR
5
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
6
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
7
+ # ------------------------------------------------------------------------
8
+ # Modified from DETR (https://github.com/facebookresearch/detr)
9
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
10
+ # ------------------------------------------------------------------------
11
+ import os
12
+ import subprocess
13
+ import time
14
+ from collections import defaultdict, deque
15
+ import datetime
16
+ import pickle
17
+ from typing import Optional, List
18
+ import yaml
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.distributed as dist
22
+ from torch import Tensor
23
+
24
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
25
+ import torchvision
26
+ if float(torchvision.__version__.split('.')[1]) < 0.5:
27
+ import math
28
+ from torchvision.ops.misc import _NewEmptyTensorOp
29
+ def _check_size_scale_factor(dim, size, scale_factor):
30
+ # type: (int, Optional[List[int]], Optional[float]) -> None
31
+ if size is None and scale_factor is None:
32
+ raise ValueError("either size or scale_factor should be defined")
33
+ if size is not None and scale_factor is not None:
34
+ raise ValueError("only one of size or scale_factor should be defined")
35
+ if not (scale_factor is not None and len(scale_factor) != dim):
36
+ raise ValueError(
37
+ "scale_factor shape must match input shape. "
38
+ "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
39
+ )
40
+ def _output_size(dim, input, size, scale_factor):
41
+ # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
42
+ assert dim == 2
43
+ _check_size_scale_factor(dim, size, scale_factor)
44
+ if size is not None:
45
+ return size
46
+ # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
47
+ assert scale_factor is not None and isinstance(scale_factor, (int, float))
48
+ scale_factors = [scale_factor, scale_factor]
49
+ # math.floor might return float in py2.7
50
+ return [
51
+ int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
52
+ ]
53
+ elif float(torchvision.__version__.split('.')[1]) < 7:
54
+ from torchvision.ops import _new_empty_tensor
55
+ from torchvision.ops.misc import _output_size
56
+
57
+
58
+ class SmoothedValue(object):
59
+ """Track a series of values and provide access to smoothed values over a
60
+ window or the global series average.
61
+ """
62
+
63
+ def __init__(self, window_size=20, fmt=None):
64
+ if fmt is None:
65
+ fmt = "{median:.4f} ({global_avg:.4f})"
66
+ self.deque = deque(maxlen=window_size)
67
+ self.total = 0.0
68
+ self.count = 0
69
+ self.fmt = fmt
70
+
71
+ def update(self, value, n=1):
72
+ self.deque.append(value)
73
+ self.count += n
74
+ self.total += value * n
75
+
76
+ def synchronize_between_processes(self):
77
+ """
78
+ Warning: does not synchronize the deque!
79
+ """
80
+ if not is_dist_avail_and_initialized():
81
+ return
82
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
83
+ dist.barrier()
84
+ dist.all_reduce(t)
85
+ t = t.tolist()
86
+ self.count = int(t[0])
87
+ self.total = t[1]
88
+
89
+ @property
90
+ def median(self):
91
+ d = torch.tensor(list(self.deque))
92
+ return d.median().item()
93
+
94
+ @property
95
+ def avg(self):
96
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
97
+ return d.mean().item()
98
+
99
+ @property
100
+ def global_avg(self):
101
+ return self.total / self.count
102
+
103
+ @property
104
+ def max(self):
105
+ return max(self.deque)
106
+
107
+ @property
108
+ def value(self):
109
+ return self.deque[-1]
110
+
111
+ def __str__(self):
112
+ return self.fmt.format(
113
+ median=self.median,
114
+ avg=self.avg,
115
+ global_avg=self.global_avg,
116
+ max=self.max,
117
+ value=self.value)
118
+
119
+
120
+ def all_gather(data):
121
+ """
122
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
123
+ Args:
124
+ data: any picklable object
125
+ Returns:
126
+ list[data]: list of data gathered from each rank
127
+ """
128
+ world_size = get_world_size()
129
+ if world_size == 1:
130
+ return [data]
131
+
132
+ # serialized to a Tensor
133
+ buffer = pickle.dumps(data)
134
+ storage = torch.ByteStorage.from_buffer(buffer)
135
+ tensor = torch.ByteTensor(storage).to("cuda")
136
+
137
+ # obtain Tensor size of each rank
138
+ local_size = torch.tensor([tensor.numel()], device="cuda")
139
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
140
+ dist.all_gather(size_list, local_size)
141
+ size_list = [int(size.item()) for size in size_list]
142
+ max_size = max(size_list)
143
+
144
+ # receiving Tensor from all ranks
145
+ # we pad the tensor because torch all_gather does not support
146
+ # gathering tensors of different shapes
147
+ tensor_list = []
148
+ for _ in size_list:
149
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
150
+ if local_size != max_size:
151
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
152
+ tensor = torch.cat((tensor, padding), dim=0)
153
+ dist.all_gather(tensor_list, tensor)
154
+
155
+ data_list = []
156
+ for size, tensor in zip(size_list, tensor_list):
157
+ buffer = tensor.cpu().numpy().tobytes()[:size]
158
+ data_list.append(pickle.loads(buffer))
159
+
160
+ return data_list
161
+
162
+
163
+ def reduce_dict(input_dict, average=True):
164
+ """
165
+ Args:
166
+ input_dict (dict): all the values will be reduced
167
+ average (bool): whether to do average or sum
168
+ Reduce the values in the dictionary from all processes so that all processes
169
+ have the averaged results. Returns a dict with the same fields as
170
+ input_dict, after reduction.
171
+ """
172
+ world_size = get_world_size()
173
+ if world_size < 2:
174
+ return input_dict
175
+ with torch.no_grad():
176
+ names = []
177
+ values = []
178
+ # sort the keys so that they are consistent across processes
179
+ for k in sorted(input_dict.keys()):
180
+ names.append(k)
181
+ values.append(input_dict[k])
182
+ values = torch.stack(values, dim=0)
183
+ dist.all_reduce(values)
184
+ if average:
185
+ values /= world_size
186
+ reduced_dict = {k: v for k, v in zip(names, values)}
187
+ return reduced_dict
188
+
189
+
190
+ class MetricLogger(object):
191
+ def __init__(self, delimiter="\t"):
192
+ self.meters = defaultdict(SmoothedValue)
193
+ self.delimiter = delimiter
194
+
195
+ def update(self, **kwargs):
196
+ for k, v in kwargs.items():
197
+ if isinstance(v, torch.Tensor):
198
+ v = v.item()
199
+ assert isinstance(v, (float, int))
200
+ self.meters[k].update(v)
201
+
202
+ def __getattr__(self, attr):
203
+ if attr in self.meters:
204
+ return self.meters[attr]
205
+ if attr in self.__dict__:
206
+ return self.__dict__[attr]
207
+ raise AttributeError("'{}' object has no attribute '{}'".format(
208
+ type(self).__name__, attr))
209
+
210
+ def __str__(self):
211
+ loss_str = []
212
+ for name, meter in self.meters.items():
213
+ loss_str.append(
214
+ "{}: {}".format(name, str(meter))
215
+ )
216
+ return self.delimiter.join(loss_str)
217
+
218
+ def synchronize_between_processes(self):
219
+ for meter in self.meters.values():
220
+ meter.synchronize_between_processes()
221
+
222
+ def add_meter(self, name, meter):
223
+ self.meters[name] = meter
224
+
225
+ def log_every(self, iterable, print_freq, header=None):
226
+ i = 0
227
+ if not header:
228
+ header = ''
229
+ start_time = time.time()
230
+ end = time.time()
231
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
232
+ data_time = SmoothedValue(fmt='{avg:.4f}')
233
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
234
+ if torch.cuda.is_available():
235
+ log_msg = self.delimiter.join([
236
+ header,
237
+ '[{0' + space_fmt + '}/{1}]',
238
+ 'eta: {eta}',
239
+ '{meters}',
240
+ 'time: {time}',
241
+ 'data: {data}',
242
+ 'max mem: {memory:.0f}'
243
+ ])
244
+ else:
245
+ log_msg = self.delimiter.join([
246
+ header,
247
+ '[{0' + space_fmt + '}/{1}]',
248
+ 'eta: {eta}',
249
+ '{meters}',
250
+ 'time: {time}',
251
+ 'data: {data}'
252
+ ])
253
+ MB = 1024.0 * 1024.0
254
+ for obj in iterable:
255
+ data_time.update(time.time() - end)
256
+ yield obj
257
+ iter_time.update(time.time() - end)
258
+ if i % print_freq == 0 or i == len(iterable) - 1:
259
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
260
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
261
+ if torch.cuda.is_available():
262
+ print(log_msg.format(
263
+ i, len(iterable), eta=eta_string,
264
+ meters=str(self),
265
+ time=str(iter_time), data=str(data_time),
266
+ memory=torch.cuda.max_memory_allocated() / MB))
267
+ else:
268
+ print(log_msg.format(
269
+ i, len(iterable), eta=eta_string,
270
+ meters=str(self),
271
+ time=str(iter_time), data=str(data_time)))
272
+ i += 1
273
+ end = time.time()
274
+ total_time = time.time() - start_time
275
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
276
+ print('{} Total time: {} ({:.4f} s / it)'.format(
277
+ header, total_time_str, total_time / len(iterable)))
278
+
279
+
280
+ def get_sha():
281
+ cwd = os.path.dirname(os.path.abspath(__file__))
282
+
283
+ def _run(command):
284
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
285
+ sha = 'N/A'
286
+ diff = "clean"
287
+ branch = 'N/A'
288
+ try:
289
+ sha = _run(['git', 'rev-parse', 'HEAD'])
290
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
291
+ diff = _run(['git', 'diff-index', 'HEAD'])
292
+ diff = "has uncommited changes" if diff else "clean"
293
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
294
+ except Exception:
295
+ pass
296
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
297
+ return message
298
+
299
+
300
+ def collate_fn(batch):
301
+ batch = list(zip(*batch))
302
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
303
+ return tuple(batch)
304
+
305
+
306
+ def _max_by_axis(the_list):
307
+ # type: (List[List[int]]) -> List[int]
308
+ maxes = the_list[0]
309
+ for sublist in the_list[1:]:
310
+ for index, item in enumerate(sublist):
311
+ maxes[index] = max(maxes[index], item)
312
+ return maxes
313
+
314
+
315
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
316
+ # TODO make this more general
317
+ if tensor_list[0].ndim == 3:
318
+ # TODO make it support different-sized images
319
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
320
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
321
+ batch_shape = [len(tensor_list)] + max_size
322
+ b, c, h, w = batch_shape
323
+ dtype = tensor_list[0].dtype
324
+ device = tensor_list[0].device
325
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
326
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
327
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
328
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
329
+ m[: img.shape[1], :img.shape[2]] = False
330
+ else:
331
+ raise ValueError('not supported')
332
+ return NestedTensor(tensor, mask)
333
+
334
+
335
+ class NestedTensor(object):
336
+ def __init__(self, tensors, mask: Optional[Tensor]):
337
+ self.tensors = tensors
338
+ self.mask = mask
339
+
340
+ def to(self, device, non_blocking=False):
341
+ # type: (Device) -> NestedTensor # noqa
342
+ cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
343
+ mask = self.mask
344
+ if mask is not None:
345
+ assert mask is not None
346
+ cast_mask = mask.to(device, non_blocking=non_blocking)
347
+ else:
348
+ cast_mask = None
349
+ return NestedTensor(cast_tensor, cast_mask)
350
+
351
+ def record_stream(self, *args, **kwargs):
352
+ self.tensors.record_stream(*args, **kwargs)
353
+ if self.mask is not None:
354
+ self.mask.record_stream(*args, **kwargs)
355
+
356
+ def decompose(self):
357
+ return self.tensors, self.mask
358
+
359
+ def __repr__(self):
360
+ return str(self.tensors)
361
+
362
+
363
+ def setup_for_distributed(is_master):
364
+ """
365
+ This function disables printing when not in master process
366
+ """
367
+ import builtins as __builtin__
368
+ builtin_print = __builtin__.print
369
+
370
+ def print(*args, **kwargs):
371
+ force = kwargs.pop('force', False)
372
+ if is_master or force:
373
+ builtin_print(*args, **kwargs)
374
+
375
+ __builtin__.print = print
376
+
377
+
378
+ def is_dist_avail_and_initialized():
379
+ if not dist.is_available():
380
+ return False
381
+ if not dist.is_initialized():
382
+ return False
383
+ return True
384
+
385
+
386
+ def get_world_size():
387
+ if not is_dist_avail_and_initialized():
388
+ return 1
389
+ return dist.get_world_size()
390
+
391
+
392
+ def get_rank():
393
+ if not is_dist_avail_and_initialized():
394
+ return 0
395
+ return dist.get_rank()
396
+
397
+
398
+ def get_local_size():
399
+ if not is_dist_avail_and_initialized():
400
+ return 1
401
+ return int(os.environ['LOCAL_SIZE'])
402
+
403
+
404
+ def get_local_rank():
405
+ if not is_dist_avail_and_initialized():
406
+ return 0
407
+ return int(os.environ['LOCAL_RANK'])
408
+
409
+
410
+ def is_main_process():
411
+ return get_rank() == 0
412
+
413
+
414
+ def save_on_master(*args, **kwargs):
415
+ if is_main_process():
416
+ torch.save(*args, **kwargs)
417
+
418
+
419
+ def init_distributed_mode(args):
420
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
421
+ args.rank = int(os.environ["RANK"])
422
+ args.world_size = int(os.environ['WORLD_SIZE'])
423
+ args.gpu = int(os.environ['LOCAL_RANK'])
424
+ args.dist_url = 'env://'
425
+ os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
426
+ elif 'SLURM_PROCID' in os.environ:
427
+ proc_id = int(os.environ['SLURM_PROCID'])
428
+ ntasks = int(os.environ['SLURM_NTASKS'])
429
+ node_list = os.environ['SLURM_NODELIST']
430
+ num_gpus = torch.cuda.device_count()
431
+ addr = subprocess.getoutput(
432
+ 'scontrol show hostname {} | head -n1'.format(node_list))
433
+ os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
434
+ os.environ['MASTER_ADDR'] = addr
435
+ os.environ['WORLD_SIZE'] = str(ntasks)
436
+ os.environ['RANK'] = str(proc_id)
437
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
438
+ os.environ['LOCAL_SIZE'] = str(num_gpus)
439
+ args.dist_url = 'env://'
440
+ args.world_size = ntasks
441
+ args.rank = proc_id
442
+ args.gpu = proc_id % num_gpus
443
+ else:
444
+ print('Not using distributed mode')
445
+ args.distributed = False
446
+ return
447
+
448
+ args.distributed = True
449
+
450
+ torch.cuda.set_device(args.gpu)
451
+ args.dist_backend = 'nccl'
452
+ print('| distributed init (rank {}): {}'.format(
453
+ args.rank, args.dist_url), flush=True)
454
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
455
+ world_size=args.world_size, rank=args.rank)
456
+ torch.distributed.barrier()
457
+ setup_for_distributed(args.rank == 0)
458
+
459
+
460
+ @torch.no_grad()
461
+ def accuracy(output, target, topk=(1,)):
462
+ """Computes the precision@k for the specified values of k"""
463
+ if target.numel() == 0:
464
+ return [torch.zeros([], device=output.device)]
465
+ maxk = max(topk)
466
+ batch_size = target.size(0)
467
+
468
+ _, pred = output.topk(maxk, 1, True, True)
469
+ pred = pred.t()
470
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
471
+
472
+ res = []
473
+ for k in topk:
474
+ correct_k = correct[:k].view(-1).float().sum(0)
475
+ res.append(correct_k.mul_(100.0 / batch_size))
476
+ return res
477
+
478
+
479
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
480
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
481
+ """
482
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
483
+ This will eventually be supported natively by PyTorch, and this
484
+ class can go away.
485
+ """
486
+ if float(torchvision.__version__[:3]) < 0.7:
487
+ if input.numel() > 0:
488
+ return torch.nn.functional.interpolate(
489
+ input, size, scale_factor, mode, align_corners
490
+ )
491
+
492
+ output_shape = _output_size(2, input, size, scale_factor)
493
+ output_shape = list(input.shape[:-2]) + list(output_shape)
494
+ if float(torchvision.__version__[:3]) < 0.5:
495
+ return _NewEmptyTensorOp.apply(input, output_shape)
496
+ return _new_empty_tensor(input, output_shape)
497
+ else:
498
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
499
+
500
+
501
+ def get_total_grad_norm(parameters, norm_type=2):
502
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
503
+ norm_type = float(norm_type)
504
+ device = parameters[0].grad.device
505
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
506
+ norm_type)
507
+ return total_norm
508
+
509
+ def inverse_sigmoid(x, eps=1e-5):
510
+ x = x.clamp(min=0, max=1)
511
+ x1 = x.clamp(min=eps)
512
+ x2 = (1 - x).clamp(min=eps)
513
+ return torch.log(x1/x2)
514
+
515
+
516
+ def to_pixel_coords(flow, h1, w1):
517
+ flow = (
518
+ torch.stack(
519
+ (
520
+ w1 * (flow[..., 0] + 1) / 2,
521
+ h1 * (flow[..., 1] + 1) / 2,
522
+ ),
523
+ axis=-1,
524
+ )
525
+ )
526
+ return flow
527
+
528
+ def read_config(file_path):
529
+ with open(file_path, 'r') as file:
530
+ config = yaml.safe_load(file)
531
+ return config
imcui/third_party/rdd/README.md ADDED
@@ -0,0 +1,197 @@
1
+ ## RDD: Robust Feature Detector and Descriptor using Deformable Transformer (CVPR 2025)
2
+ [Gonglin Chen](https://xtcpete.com/) · [Tianwen Fu](https://twfu.me/) · [Haiwei Chen](https://scholar.google.com/citations?user=LVWRssoAAAAJ&hl=en) · [Wenbin Teng](https://wbteng9526.github.io/) · [Hanyuan Xiao](https://corneliushsiao.github.io/index.html) · [Yajie Zhao](https://ict.usc.edu/about-us/leadership/research-leadership/yajie-zhao/)
3
+
4
+ [Project Page](https://xtcpete.github.io/rdd/)
5
+
6
+ ## Table of Contents
7
+ - [Updates](#updates)
8
+ - [Installation](#installation)
9
+ - [Usage](#usage)
10
+ - [Inference](#inference)
11
+ - [Evaluation](#evaluation)
12
+ - [Training](#training)
13
+ - [Citation](#citation)
14
+ - [License](#license)
15
+ - [Acknowledgements](#acknowledgements)
16
+
17
+ ## Updates
18
+
19
+ - SfM reconstruction through [COLMAP](https://github.com/colmap/colmap.git) added. We provide a ready-to-use [notebook](./demo_sfm.ipynb) for a simple example. Code adapted from [hloc](https://github.com/cvg/Hierarchical-Localization.git).
20
+
21
+ - Training code and new weights released.
22
+
23
+ - We have updated the training code compared to what was described in the paper. In the original setup, RDD was trained on the MegaDepth and Air-to-Ground datasets by resizing all images to the training resolution. In this release, we retrained RDD on MegaDepth only, using a combination of resizing and cropping, a strategy used by [ALIKE](https://github.com/Shiaoming/ALIKE). This change significantly improves robustness.
24
+
25
+ <table>
26
+ <tr>
27
+ <th></th>
28
+ <th colspan="3">MegaDepth-1500</th>
29
+ <th colspan="3">MegaDepth-View</th>
30
+ <th colspan="3">Air-to-Ground</th>
31
+ </tr>
32
+ <tr>
33
+ <td></td>
34
+ <td>AUC 5&deg;</td><td>AUC 10&deg;</td><td>AUC 20&deg;</td>
35
+ <td>AUC 5&deg;</td><td>AUC 10&deg;</td><td>AUC 20&deg;</td>
36
+ <td>AUC 5&deg;</td><td>AUC 10&deg;</td><td>AUC 20&deg;</td>
37
+ </tr>
38
+ <tr>
39
+ <td>RDD-v2</td>
40
+ <td>52.4</td><td>68.5</td><td>80.1</td>
41
+ <td>52.0</td><td>67.1</td><td>78.2</td>
42
+ <td>45.8</td><td>58.6</td><td>71.0</td>
43
+ </tr>
44
+ <tr>
45
+ <td>RDD-v1</td>
46
+ <td>48.2</td><td>65.2</td><td>78.3</td>
47
+ <td>38.3</td><td>53.1</td><td>65.6</td>
48
+ <td>41.4</td><td>56.0</td><td>67.8</td>
49
+ </tr>
50
+ <tr>
51
+ <td>RDD-v2+LG</td>
52
+ <td>53.3</td><td>69.8</td><td>82.0</td>
53
+ <td>59.0</td><td>74.2</td><td>84.0</td>
54
+ <td>54.8</td><td>69.0</td><td>79.1</td>
55
+ </tr>
56
+ <tr>
57
+ <td>RDD-v1+LG</td>
58
+ <td>52.3</td><td>68.9</td><td>81.8</td>
59
+ <td>54.2</td><td>69.3</td><td>80.3</td>
60
+ <td>55.1</td><td>68.9</td><td>78.9</td>
61
+ </tr>
62
+ </table>
63
+
64
+ ## Installation
65
+
66
+ ```bash
67
+ git clone --recursive https://github.com/xtcpete/rdd
68
+ cd RDD
69
+
70
+ # Create conda env
71
+ conda create -n rdd python=3.10 pip
72
+ conda activate rdd
73
+
74
+ # Install CUDA
75
+ conda install -c nvidia/label/cuda-11.8.0 cuda-toolkit
76
+ # Install torch
77
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
78
+ # Install all dependencies
79
+ pip install -r requirements.txt
80
+ # Compile custom operations
81
+ cd ./RDD/models/ops
82
+ pip install -e .
83
+ ```
84
+
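+ Optionally, you can sanity-check the compiled deformable-attention op before moving on (a CUDA GPU is required): running `python test.py` from `./RDD/models/ops` compares the CUDA kernel against the pure-PyTorch reference implementation and runs a numerical `gradcheck`.
+ 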
85
+ We provide the [download link](https://drive.google.com/drive/folders/1QgVaqm4iTUCqbWb7_Fi6mX09EHTId0oA?usp=sharing) to:
86
+ - the MegaDepth-1500 test set
87
+ - the MegaDepth-View test set
88
+ - the Air-to-Ground test set
89
+ - 2 pretrained models, RDD and LightGlue for matching RDD
90
+
91
+ Create a `data` folder and unzip the downloaded test data into it.
92
+
93
+ Create a `weights` folder, add the downloaded weights to it, and you are ready to go.
94
+
95
+ ## Usage
96
+ For your convenience, we provide a ready-to-use [notebook](./demo_matching.ipynb) for some examples.
97
+
98
+ ### Inference
99
+
100
+ ```python
101
+ import torch
+ from RDD.RDD import build
102
+
103
+ RDD_model = build()
104
+
105
+ output = RDD_model.extract(torch.randn(1, 3, 480, 640))
106
+ ```
107
+
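+ For two-view matching, the benchmark scripts in this repository wrap the model in `RDD_helper`. The sketch below is a minimal example that mirrors their usage; the weight path, image paths, threshold, and resize value are illustrative assumptions taken from the benchmark defaults:
+ 
+ ```python
+ import cv2
+ from RDD.RDD import build
+ from RDD.RDD_helper import RDD_helper
+ 
+ # Mirrors benchmarks/mega_1500.py: build the model, wrap it, and match two images.
+ model = build(weights='./weights/RDD-v2.pth')
+ model.eval()
+ matcher = RDD_helper(model)
+ 
+ im_A = cv2.imread('image_A.jpg')  # placeholder paths
+ im_B = cv2.imread('image_B.jpg')
+ kpts0, kpts1, conf = matcher.match(im_A, im_B, thr=0.01, resize=1600)          # sparse matching
+ # kpts0, kpts1, conf = matcher.match_dense(im_A, im_B, thr=0.01, resize=1600)  # dense matching
+ # kpts0, kpts1, conf = matcher.match_lg(im_A, im_B, thr=0.01, resize=1600)     # RDD + LightGlue
+ ```
+ 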
108
+ ### Evaluation
109
+
110
+ Please note that due to differences in GPU architecture and the stochastic nature of RANSAC, you may observe slightly different results; however, they should be very close to those reported in the paper. To reproduce the numbers in the paper, use the v1 weights instead.
111
+
112
+ Results can be visualized by passing the `--plot` argument.
113
+
114
+ **MegaDepth-1500**
115
+
116
+ ```bash
117
+ # Sparse matching
118
+ python ./benchmarks/mega_1500.py
119
+
120
+ # Dense matching
121
+ python ./benchmarks/mega_1500.py --method dense
122
+
123
+ # LightGlue
124
+ python ./benchmarks/mega_1500.py --method lightglue
125
+ ```
126
+
127
+ **MegaDepth-View**
128
+
129
+ ```bash
130
+ # Sparse matching
131
+ python ./benchmarks/mega_view.py
132
+
133
+ # Dense matching
134
+ python ./benchmarks/mega_view.py --method dense
135
+
136
+ # LightGlue
137
+ python ./benchmarks/mega_view.py --method lightglue
138
+ ```
139
+
140
+ **Air-to-Ground**
141
+
142
+ ```bash
143
+ # Sparse matching
144
+ python ./benchmarks/air_ground.py
145
+
146
+ # Dense matching
147
+ python ./benchmarks/air_ground.py --method dense
148
+
149
+ # LightGlue
150
+ python ./benchmarks/air_ground.py --method lightglue
151
+ ```
152
+
153
+ ### Training
154
+
155
+ 1. Download the MegaDepth dataset using [download.sh](./data/megadepth/download.sh) and the megadepth_indices from [LoFTR](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets). Then the MegaDepth root folder should look like the following:
156
+ ```bash
157
+ ./data/megadepth/megadepth_indices # indices
158
+ ./data/megadepth/depth_undistorted # depth maps
159
+ ./data/megadepth/Undistorted_SfM # images and poses
160
+ ./data/megadepth/scene_info # indices for training LightGlue
161
+ ```
162
+ 2. Then you can train RDD in two steps: the descriptor first
163
+ ```bash
164
+ # distributed training with 8 gpus
165
+ python -m training.train --ckpt_save_path ./ckpt_descriptor --distributed --batch_size 32
166
+
167
+ # single gpu
168
+ python -m training.train --ckpt_save_path ./ckpt_descriptor
169
+ ```
170
+ and then the detector
171
+ ```bash
172
+ python -m training.train --ckpt_save_path ./ckpt_detector --weights ./ckpt_descriptor/RDD_best.pth --train_detector --training_res 480
173
+ ```
174
+
175
+ I am working on re-collecting the Air-to-Ground dataset because of licensing issues.
176
+
177
+ ## Citation
178
+ ```
179
+ @inproceedings{gonglin2025rdd,
180
+ title = {RDD: Robust Feature Detector and Descriptor using Deformable Transformer},
181
+ author = {Chen, Gonglin and Fu, Tianwen and Chen, Haiwei and Teng, Wenbin and Xiao, Hanyuan and Zhao, Yajie},
182
+ booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
183
+ year = {2025}
184
+ }
185
+ ```
186
+
187
+
188
+ ## License
189
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
190
+
191
+ ## Acknowledgements
192
+
193
+ We thank these great repositories: [ALIKE](https://github.com/Shiaoming/ALIKE), [LoFTR](https://github.com/zju3dv/LoFTR), [DeDoDe](https://github.com/Parskatt/DeDoDe), [XFeat](https://github.com/verlab/accelerated_features), [LightGlue](https://github.com/cvg/LightGlue), [Kornia](https://github.com/kornia/kornia), and [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR), and many other inspiring works in the community.
194
+
195
+ LightGlue is trained with [Glue Factory](https://github.com/cvg/glue-factory).
196
+
197
+ Supported by the Intelligence Advanced Research Projects Activity (IARPA) via Department of Interior/Interior Business Center (DOI/IBC) contract number 140D0423C0075. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes notwithstanding any copyright annotation thereon. Disclaimer: The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of IARPA, DOI/IBC, or the U.S. Government. We would like to thank Yayue Chen for her help with visualization.
imcui/third_party/rdd/benchmarks/air_ground.py ADDED
@@ -0,0 +1,247 @@
1
+ import sys
2
+ sys.path.append(".")
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ import tqdm
7
+ import cv2
8
+ import argparse
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib
11
+ from RDD.RDD_helper import RDD_helper
12
+ from RDD.RDD import build
13
+ import os
14
+ from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha
15
+
16
+ def make_matching_figure(
17
+ img0, img1, mkpts0, mkpts1, color,
18
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
19
+ # draw image pair
20
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
21
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
22
+ axes[0].imshow(img0, cmap='gray')
23
+ axes[1].imshow(img1, cmap='gray')
24
+ for i in range(2): # clear all frames
25
+ axes[i].get_yaxis().set_ticks([])
26
+ axes[i].get_xaxis().set_ticks([])
27
+ for spine in axes[i].spines.values():
28
+ spine.set_visible(False)
29
+ plt.tight_layout(pad=1)
30
+
31
+ if kpts0 is not None:
32
+ assert kpts1 is not None
33
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
34
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
35
+
36
+ # draw matches
37
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
38
+ fig.canvas.draw()
39
+ transFigure = fig.transFigure.inverted()
40
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
41
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
42
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
43
+ (fkpts0[i, 1], fkpts1[i, 1]),
44
+ transform=fig.transFigure, c=color[i], linewidth=1)
45
+ for i in range(len(mkpts0))]
46
+
47
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
48
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
49
+
50
+ # put txts
51
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
52
+ fig.text(
53
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
54
+ fontsize=15, va='top', ha='left', color=txt_color)
55
+
56
+ # save or return figure
57
+ if path:
58
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
59
+ plt.close()
60
+ else:
61
+ return fig
62
+
63
+ def error_colormap(err, thr, alpha=1.0):
64
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
65
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
66
+ return np.clip(
67
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
68
+
69
+ def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None):
70
+ conf_thr = 1e-4
71
+
72
+ img0 = np.array(img0)
73
+ img1 = np.array(img1)
74
+
75
+ kpts0 = kpts0
76
+ kpts1 = kpts1
77
+
78
+ epi_errs = epi_errs.cpu().numpy()
79
+ correct_mask = epi_errs < conf_thr
80
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
81
+ n_correct = np.sum(correct_mask)
82
+
83
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
84
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
85
+
86
+ # matching info
87
+ if alpha == 'dynamic':
88
+ alpha = dynamic_alpha(len(correct_mask))
89
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
90
+
91
+ text = [
92
+ f'#Matches {len(kpts0)}',
93
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
94
+ f'e_t: {e_t:.2f} | e_R: {e_R:.2f}',
95
+ ]
96
+
97
+ # make the figure
98
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
99
+ color, text=text, path=path)
100
+ return figure
101
+
102
+ class AirGroundPoseMNNBenchmark:
103
+ def __init__(self, data_root="./data/air_ground", scene_names = None) -> None:
104
+ if scene_names is None:
105
+ self.scene_names = [
106
+ "indices.npz",
107
+ ]
108
+ # self.scene_names = ["0022_0.5_0.7.npz",]
109
+ else:
110
+ self.scene_names = scene_names
111
+ self.scenes = [
112
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
113
+ for scene in self.scene_names
114
+ ]
115
+ self.data_root = data_root
116
+
117
+ def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=10, plot=False, method='sparse'):
118
+ with torch.no_grad():
119
+ data_root = self.data_root
120
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
121
+ thresholds = [5, 10, 20]
122
+ for scene_ind in range(len(self.scenes)):
123
+ import os
124
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
125
+ scene = self.scenes[scene_ind]
126
+ indices = scene['pair_info']
127
+ idx = 0
128
+ for pair in tqdm.tqdm(indices):
129
+
130
+ pairs = pair['pair_names']
131
+ K0 = pair['intrinsic'][0].copy().astype(np.float32)
132
+ T0 = pair['pose'][0].copy().astype(np.float32)
133
+ R0, t0 = T0[:3, :3], T0[:3, 3]
134
+ K1 = pair['intrinsic'][1].copy().astype(np.float32)
135
+ T1 = pair['pose'][1].copy().astype(np.float32)
136
+ R1, t1 = T1[:3, :3], T1[:3, 3]
137
+ R, t = compute_relative_pose(R0, t0, R1, t1)
138
+ T0_to_1 = np.concatenate((R,t[:,None]), axis=-1)
139
+ im_A_path = f"{data_root}/images/{pairs[0]}"
140
+ im_B_path = f"{data_root}/images/{pairs[1]}"
141
+
142
+ im_A = cv2.imread(im_A_path)
143
+ im_B = cv2.imread(im_B_path)
144
+
145
+ if method == 'dense':
146
+ kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600)
147
+ elif method == 'lightglue':
148
+ kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600)
149
+ elif method == 'sparse':
150
+ kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600)
151
+ else:
152
+ raise ValueError(f"Invalid method {method}")
153
+
154
+ im_A = Image.open(im_A_path)
155
+ w0, h0 = im_A.size
156
+ im_B = Image.open(im_B_path)
157
+ w1, h1 = im_B.size
158
+ if scale_intrinsics:
159
+ scale0 = 840 / max(w0, h0)
160
+ scale1 = 840 / max(w1, h1)
161
+ w0, h0 = scale0 * w0, scale0 * h0
162
+ w1, h1 = scale1 * w1, scale1 * h1
163
+ K0, K1 = K0.copy(), K1.copy()
164
+ K0[:2] = K0[:2] * scale0
165
+ K1[:2] = K1[:2] * scale1
166
+
167
+ threshold = 0.5
168
+ if calibrated:
169
+ norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
170
+ ret = estimate_pose(
171
+ kpts0,
172
+ kpts1,
173
+ K0,
174
+ K1,
175
+ norm_threshold,
176
+ conf=0.99999,
177
+ )
178
+ if ret is not None:
179
+ R_est, t_est, mask = ret
180
+ T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) #
181
+ T0_to_1 = np.concatenate((R, t[:,None]), axis=-1)
182
+ e_t, e_R = compute_pose_error(T0_to_1_est, R, t)
183
+
184
+ epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1)
185
+ if scene_ind % plot_every_iter == 0 and plot:
186
+
187
+ if not os.path.exists(f'outputs/air_ground/{model_name}_{method}'):
188
+ os.mkdir(f'outputs/air_ground/{model_name}_{method}')
189
+ name = f'outputs/air_ground/{model_name}_{method}/{scene_name}_{idx}.png'
190
+ _make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name)
191
+ e_pose = max(e_t, e_R)
192
+
193
+ tot_e_t.append(e_t)
194
+ tot_e_R.append(e_R)
195
+ tot_e_pose.append(e_pose)
196
+ idx += 1
197
+
198
+ tot_e_pose = np.array(tot_e_pose)
199
+ auc = pose_auc(tot_e_pose, thresholds)
200
+ acc_5 = (tot_e_pose < 5).mean()
201
+ acc_10 = (tot_e_pose < 10).mean()
202
+ acc_15 = (tot_e_pose < 15).mean()
203
+ acc_20 = (tot_e_pose < 20).mean()
204
+ map_5 = acc_5
205
+ map_10 = np.mean([acc_5, acc_10])
206
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
207
+ print(f"{model_name} auc: {auc}")
208
+ return {
209
+ "auc_5": auc[0],
210
+ "auc_10": auc[1],
211
+ "auc_20": auc[2],
212
+ "map_5": map_5,
213
+ "map_10": map_10,
214
+ "map_20": map_20,
215
+ }
216
+
217
+
218
+
219
+ def parse_arguments():
220
+ parser = argparse.ArgumentParser(description="Testing script.")
221
+
222
+ parser.add_argument("--data_root", type=str, default="./data/air_ground", help="Path to the Air-to-Ground test dataset.")
223
+
224
+ parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.")
225
+
226
+ parser.add_argument("--plot", action="store_true", help="Whether to plot the results.")
227
+
228
+ parser.add_argument("--method", type=str, default="sparse", help="Method for matching.")
229
+
230
+ return parser.parse_args()
231
+
232
+ if __name__ == "__main__":
233
+ args = parse_arguments()
234
+
235
+ if not os.path.exists('outputs'):
236
+ os.mkdir('outputs')
237
+ if not os.path.exists(f'outputs/air_ground'):
238
+ os.mkdir(f'outputs/air_ground')
239
+ model = build(weights=args.weights)
240
+ benchmark = AirGroundPoseMNNBenchmark(data_root=args.data_root)
241
+ model.eval()
242
+ model_helper = RDD_helper(model)
243
+ with torch.no_grad():
244
+ method = args.method
245
+ out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method)
246
+ with open(f'outputs/air_ground/RDD_{method}.txt', 'w') as f:
247
+ f.write(str(out))
imcui/third_party/rdd/benchmarks/mega_1500.py ADDED
@@ -0,0 +1,255 @@
1
+ import sys
2
+ sys.path.append(".")
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ import tqdm
7
+ import cv2
8
+ import argparse
9
+ from RDD.RDD_helper import RDD_helper
10
+ from RDD.RDD import build
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib
13
+ import os
14
+ from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha
15
+
16
+ def make_matching_figure(
17
+ img0, img1, mkpts0, mkpts1, color,
18
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
19
+ # draw image pair
20
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
21
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
22
+ axes[0].imshow(img0, cmap='gray')
23
+ axes[1].imshow(img1, cmap='gray')
24
+ for i in range(2): # clear all frames
25
+ axes[i].get_yaxis().set_ticks([])
26
+ axes[i].get_xaxis().set_ticks([])
27
+ for spine in axes[i].spines.values():
28
+ spine.set_visible(False)
29
+ plt.tight_layout(pad=1)
30
+
31
+ if kpts0 is not None:
32
+ assert kpts1 is not None
33
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
34
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
35
+
36
+ # draw matches
37
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
38
+ fig.canvas.draw()
39
+ transFigure = fig.transFigure.inverted()
40
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
41
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
42
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
43
+ (fkpts0[i, 1], fkpts1[i, 1]),
44
+ transform=fig.transFigure, c=color[i], linewidth=1)
45
+ for i in range(len(mkpts0))]
46
+
47
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
48
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
49
+
50
+ # put txts
51
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
52
+ fig.text(
53
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
54
+ fontsize=15, va='top', ha='left', color=txt_color)
55
+
56
+ # save or return figure
57
+ if path:
58
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
59
+ plt.close()
60
+ else:
61
+ return fig
62
+
63
+ def error_colormap(err, thr, alpha=1.0):
64
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
65
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
66
+ return np.clip(
67
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
68
+
69
+ def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None):
70
+ conf_thr = 1e-4
71
+
72
+ img0 = np.array(img0)
73
+ img1 = np.array(img1)
74
+
75
+ kpts0 = kpts0
76
+ kpts1 = kpts1
77
+
78
+ epi_errs = epi_errs.cpu().numpy()
79
+ correct_mask = epi_errs < conf_thr
80
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
81
+ n_correct = np.sum(correct_mask)
82
+
83
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
84
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
85
+
86
+ # matching info
87
+ if alpha == 'dynamic':
88
+ alpha = dynamic_alpha(len(correct_mask))
89
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
90
+
91
+ text = [
92
+ f'#Matches {len(kpts0)}',
93
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
94
+ f'e_t: {e_t:.2f} | e_R: {e_R:.2f}',
95
+ ]
96
+
97
+ # make the figure
98
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
99
+ color, text=text, path=path)
100
+ return figure
101
+
102
+ class MegaDepthPoseMNNBenchmark:
103
+ def __init__(self, data_root="./megadepth_test_1500", scene_names = None) -> None:
104
+ if scene_names is None:
105
+ self.scene_names = [
106
+ "0015_0.1_0.3.npz",
107
+ "0015_0.3_0.5.npz",
108
+ "0022_0.1_0.3.npz",
109
+ "0022_0.3_0.5.npz",
110
+ "0022_0.5_0.7.npz",
111
+ ]
112
+
113
+ else:
114
+ self.scene_names = scene_names
115
+ self.scenes = [
116
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
117
+ for scene in self.scene_names
118
+ ]
119
+ self.data_root = data_root
120
+
121
+ def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=1, plot=False, method='sparse'):
122
+
123
+ with torch.no_grad():
124
+ data_root = self.data_root
125
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
126
+ thresholds = [5, 10, 20]
127
+ for scene_ind in range(len(self.scenes)):
129
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
130
+ print(f"Processing {scene_name}")
131
+ scene = self.scenes[scene_ind]
132
+ pairs = scene["pair_infos"]
133
+ intrinsics = scene["intrinsics"]
134
+ poses = scene["poses"]
135
+ im_paths = scene["image_paths"]
136
+ pair_inds = range(len(pairs))
137
+ for pairind in tqdm.tqdm(pair_inds):
138
+ idx0, idx1 = pairs[pairind][0]
139
+ K0 = intrinsics[idx0].copy()
140
+ T0 = poses[idx0].copy()
141
+ R0, t0 = T0[:3, :3], T0[:3, 3]
142
+ K1 = intrinsics[idx1].copy()
143
+ T1 = poses[idx1].copy()
144
+ R1, t1 = T1[:3, :3], T1[:3, 3]
145
+ R, t = compute_relative_pose(R0, t0, R1, t1)
146
+ T0_to_1 = np.concatenate((R,t[:,None]), axis=-1)
147
+ im_A_path = f"{data_root}/{im_paths[idx0]}"
148
+ im_B_path = f"{data_root}/{im_paths[idx1]}"
149
+
150
+ im_A = cv2.imread(im_A_path)
151
+ im_B = cv2.imread(im_B_path)
152
+
153
+ if method == 'dense':
154
+ kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600)
155
+ elif method == 'lightglue':
156
+ kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600)
157
+ elif method == 'sparse':
158
+ kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600)
159
+ else:
160
+ kpts0, kpts1, conf = model_helper.match_3rd_party(im_A, im_B, thr=0.01, resize=1600, model=method)
161
+
162
+ im_A = Image.open(im_A_path)
163
+ w0, h0 = im_A.size
164
+ im_B = Image.open(im_B_path)
165
+ w1, h1 = im_B.size
166
+ if scale_intrinsics:
167
+ scale0 = 840 / max(w0, h0)
168
+ scale1 = 840 / max(w1, h1)
169
+ w0, h0 = scale0 * w0, scale0 * h0
170
+ w1, h1 = scale1 * w1, scale1 * h1
171
+ K0, K1 = K0.copy(), K1.copy()
172
+ K0[:2] = K0[:2] * scale0
173
+ K1[:2] = K1[:2] * scale1
174
+
175
+
176
+ threshold = 0.5
177
+ if calibrated:
178
+ norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
179
+ ret = estimate_pose(
180
+ kpts0,
181
+ kpts1,
182
+ K0,
183
+ K1,
184
+ norm_threshold,
185
+ conf=0.99999,
186
+ )
187
+ if ret is not None:
188
+ R_est, t_est, mask = ret
189
+ T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) #
190
+ T0_to_1 = np.concatenate((R, t[:,None]), axis=-1)
191
+ e_t, e_R = compute_pose_error(T0_to_1_est, R, t)
192
+
193
+ epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1)
194
+ if scene_ind % plot_every_iter == 0 and plot:
195
+
196
+ if not os.path.exists(f'outputs/mega_1500/{model_name}_{method}'):
197
+ os.mkdir(f'outputs/mega_1500/{model_name}_{method}')
198
+ name = f'outputs/mega_1500/{model_name}_{method}/{scene_name}_{pairind}.png'
199
+ _make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name)
200
+ e_pose = max(e_t, e_R)
201
+
202
+ tot_e_t.append(e_t)
203
+ tot_e_R.append(e_R)
204
+ tot_e_pose.append(e_pose)
205
+
206
+ tot_e_pose = np.array(tot_e_pose)
207
+ auc = pose_auc(tot_e_pose, thresholds)
208
+ acc_5 = (tot_e_pose < 5).mean()
209
+ acc_10 = (tot_e_pose < 10).mean()
210
+ acc_15 = (tot_e_pose < 15).mean()
211
+ acc_20 = (tot_e_pose < 20).mean()
212
+ map_5 = acc_5
213
+ map_10 = np.mean([acc_5, acc_10])
214
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
215
+ print(f"{model_name} auc: {auc}")
216
+ return {
217
+ "auc_5": auc[0],
218
+ "auc_10": auc[1],
219
+ "auc_20": auc[2],
220
+ "map_5": map_5,
221
+ "map_10": map_10,
222
+ "map_20": map_20,
223
+ }
224
+
225
+
226
+ def parse_arguments():
227
+ parser = argparse.ArgumentParser(description="Testing script.")
228
+
229
+ parser.add_argument("--data_root", type=str, default="./data/megadepth_test_1500", help="Path to the MegaDepth dataset.")
230
+
231
+ parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.")
232
+
233
+ parser.add_argument("--plot", action="store_true", help="Whether to plot the results.")
234
+
235
+ parser.add_argument("--method", type=str, default="sparse", help="Method for matching.")
236
+
237
+ return parser.parse_args()
238
+
239
+ if __name__ == "__main__":
240
+ args = parse_arguments()
241
+ if not os.path.exists('outputs'):
242
+ os.mkdir('outputs')
243
+
244
+ if not os.path.exists(f'outputs/mega_1500'):
245
+ os.mkdir(f'outputs/mega_1500')
246
+
247
+ model = build(weights=args.weights)
248
+ benchmark = MegaDepthPoseMNNBenchmark(data_root=args.data_root)
249
+ model.eval()
250
+ model_helper = RDD_helper(model)
251
+ with torch.no_grad():
252
+ method = args.method
253
+ out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method)
254
+ with open(f'outputs/mega_1500/RDD_{method}.txt', 'w') as f:
255
+ f.write(str(out))
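
Note that the map_* values returned above are cumulative means over the accuracy buckets rather than accuracies at single thresholds; an illustrative calculation (numbers are made up):

import numpy as np

# illustrative accuracies at the 5 / 10 / 15 / 20 degree thresholds
acc_5, acc_10, acc_15, acc_20 = 0.55, 0.70, 0.78, 0.83

map_5 = acc_5                                        # 0.55
map_10 = np.mean([acc_5, acc_10])                    # 0.625
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])    # 0.715
print(map_5, map_10, map_20)
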
imcui/third_party/rdd/benchmarks/mega_view.py ADDED
@@ -0,0 +1,250 @@
1
+ import sys
2
+ sys.path.append(".")
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ import tqdm
7
+ import cv2
8
+ import argparse
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib
11
+ from RDD.RDD import build
12
+ from RDD.RDD_helper import RDD_helper
13
+ import os
14
+ from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha
15
+
16
+ def make_matching_figure(
17
+ img0, img1, mkpts0, mkpts1, color,
18
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
19
+ # draw image pair
20
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
21
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
22
+ axes[0].imshow(img0, cmap='gray')
23
+ axes[1].imshow(img1, cmap='gray')
24
+ for i in range(2): # clear all frames
25
+ axes[i].get_yaxis().set_ticks([])
26
+ axes[i].get_xaxis().set_ticks([])
27
+ for spine in axes[i].spines.values():
28
+ spine.set_visible(False)
29
+ plt.tight_layout(pad=1)
30
+
31
+ if kpts0 is not None:
32
+ assert kpts1 is not None
33
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
34
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
35
+
36
+ # draw matches
37
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
38
+ fig.canvas.draw()
39
+ transFigure = fig.transFigure.inverted()
40
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
41
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
42
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
43
+ (fkpts0[i, 1], fkpts1[i, 1]),
44
+ transform=fig.transFigure, c=color[i], linewidth=1)
45
+ for i in range(len(mkpts0))]
46
+
47
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
48
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
49
+
50
+ # put txts
51
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
52
+ fig.text(
53
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
54
+ fontsize=15, va='top', ha='left', color=txt_color)
55
+
56
+ # save or return figure
57
+ if path:
58
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
59
+ plt.close()
60
+ else:
61
+ return fig
62
+
63
+ def error_colormap(err, thr, alpha=1.0):
64
+ assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
65
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
66
+ return np.clip(
67
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
68
+
69
+ def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None):
70
+ conf_thr = 1e-4
71
+
72
+ img0 = np.array(img0)
73
+ img1 = np.array(img1)
74
+
75
+ kpts0 = kpts0
76
+ kpts1 = kpts1
77
+
78
+ epi_errs = epi_errs.cpu().numpy()
79
+ correct_mask = epi_errs < conf_thr
80
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
81
+ n_correct = np.sum(correct_mask)
82
+
83
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
84
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
85
+
86
+ # matching info
87
+ if alpha == 'dynamic':
88
+ alpha = dynamic_alpha(len(correct_mask))
89
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
90
+
91
+ text = [
92
+ f'#Matches {len(kpts0)}',
93
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
94
+ f'e_t: {e_t:.2f} | e_R: {e_R:.2f}',
95
+ ]
96
+
97
+ # make the figure
98
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
99
+ color, text=text, path=path)
100
+ return figure
101
+
102
+ class MegaDepthPoseMNNBenchmark:
103
+ def __init__(self, data_root="./megadepth_test_1500", scene_names = None) -> None:
104
+ if scene_names is None:
105
+ self.scene_names = [
106
+ "hard_indices.npz",
107
+ ]
108
+ # self.scene_names = ["0022_0.5_0.7.npz",]
109
+ else:
110
+ self.scene_names = scene_names
111
+ self.scenes = [
112
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
113
+ for scene in self.scene_names
114
+ ]
115
+ self.data_root = data_root
116
+
117
+ def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=1, plot=False, method='sparse'):
118
+ with torch.no_grad():
119
+ data_root = self.data_root
120
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
121
+ thresholds = [5, 10, 20]
122
+ for scene_ind in range(len(self.scenes)):
123
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
124
+ scene = self.scenes[scene_ind]
125
+ indices = scene['indices']
126
+ idx = 0
127
+
128
+ for pair in tqdm.tqdm(indices):
129
+
130
+ pairs = pair['pair_names']
131
+ K0 = pair['intrisinic'][0].copy().astype(np.float32)
132
+ T0 = pair['pose'][0].copy().astype(np.float32)
133
+ R0, t0 = T0[:3, :3], T0[:3, 3]
134
+ K1 = pair['intrisinic'][1].copy().astype(np.float32)
135
+ T1 = pair['pose'][1].copy().astype(np.float32)
136
+ R1, t1 = T1[:3, :3], T1[:3, 3]
137
+ R, t = compute_relative_pose(R0, t0, R1, t1)
138
+ T0_to_1 = np.concatenate((R,t[:,None]), axis=-1)
139
+ im_A_path = f"{data_root}/images/{pairs[0]}"
140
+ im_B_path = f"{data_root}/images/{pairs[1]}"
141
+
142
+ im_A = cv2.imread(im_A_path)
143
+ im_B = cv2.imread(im_B_path)
144
+
145
+ if method == 'dense':
146
+ kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600)
147
+ elif method == 'lightglue':
148
+ kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600)
149
+ elif method == 'sparse':
150
+ kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600)
151
+ else:
152
+ raise ValueError(f"Invalid method {method}")
153
+
154
+ im_A = Image.open(im_A_path)
155
+ w0, h0 = im_A.size
156
+ im_B = Image.open(im_B_path)
157
+ w1, h1 = im_B.size
158
+
159
+ if scale_intrinsics:
160
+ scale0 = 840 / max(w0, h0)
161
+ scale1 = 840 / max(w1, h1)
162
+ w0, h0 = scale0 * w0, scale0 * h0
163
+ w1, h1 = scale1 * w1, scale1 * h1
164
+ K0, K1 = K0.copy(), K1.copy()
165
+ K0[:2] = K0[:2] * scale0
166
+ K1[:2] = K1[:2] * scale1
167
+
168
+ threshold = 0.5
169
+ if calibrated:
170
+ norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
171
+ ret = estimate_pose(
172
+ kpts0,
173
+ kpts1,
174
+ K0,
175
+ K1,
176
+ norm_threshold,
177
+ conf=0.99999,
178
+ )
179
+ if ret is not None:
180
+ R_est, t_est, mask = ret
181
+ T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) #
182
+ T0_to_1 = np.concatenate((R, t[:,None]), axis=-1)
183
+ e_t, e_R = compute_pose_error(T0_to_1_est, R, t)
184
+
185
+ epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1)
186
+ if scene_ind % plot_every_iter == 0 and plot:
187
+
188
+ if not os.path.exists(f'outputs/mega_view/{model_name}_{method}'):
189
+ os.mkdir(f'outputs/mega_view/{model_name}_{method}')
190
+ name = f'outputs/mega_view/{model_name}_{method}/{scene_name}_{idx}.png'
191
+ _make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name)
192
+ e_pose = max(e_t, e_R)
193
+
194
+ tot_e_t.append(e_t)
195
+ tot_e_R.append(e_R)
196
+ tot_e_pose.append(e_pose)
197
+ idx += 1
198
+
199
+ tot_e_pose = np.array(tot_e_pose)
200
+ auc = pose_auc(tot_e_pose, thresholds)
201
+ acc_5 = (tot_e_pose < 5).mean()
202
+ acc_10 = (tot_e_pose < 10).mean()
203
+ acc_15 = (tot_e_pose < 15).mean()
204
+ acc_20 = (tot_e_pose < 20).mean()
205
+ map_5 = acc_5
206
+ map_10 = np.mean([acc_5, acc_10])
207
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
208
+ print(f"{model_name} auc: {auc}")
209
+ return {
210
+ "auc_5": auc[0],
211
+ "auc_10": auc[1],
212
+ "auc_20": auc[2],
213
+ "map_5": map_5,
214
+ "map_10": map_10,
215
+ "map_20": map_20,
216
+ }
217
+
218
+
219
+
220
+ def parse_arguments():
221
+ parser = argparse.ArgumentParser(description="Testing script.")
222
+
223
+ parser.add_argument("--data_root", type=str, default="./data/megadepth_view", help="Path to the MegaDepth dataset.")
224
+
225
+ parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.")
226
+
227
+ parser.add_argument("--plot", action="store_true", help="Whether to plot the results.")
228
+
229
+ parser.add_argument("--method", type=str, default="sparse", help="Method for matching.")
230
+
231
+ return parser.parse_args()
232
+
233
+ if __name__ == "__main__":
234
+ args = parse_arguments()
235
+ if not os.path.exists('outputs'):
236
+ os.mkdir('outputs')
237
+
238
+ if not os.path.exists(f'outputs/mega_view'):
239
+ os.mkdir(f'outputs/mega_view')
240
+ model = build(weights=args.weights)
241
+ benchmark = MegaDepthPoseMNNBenchmark(data_root=args.data_root)
242
+ model.eval()
243
+ model_helper = RDD_helper(model)
244
+ with torch.no_grad():
245
+ method = args.method
246
+ out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method)
247
+ with open(f'outputs/mega_view/RDD_{method}.txt', 'w') as f:
248
+ f.write(str(out))
249
+
250
+
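
Each benchmark converts the two absolute MegaDepth poses into a relative pose before scoring. A restatement of compute_relative_pose from benchmarks/utils.py (added below), with a quick sanity check:

import numpy as np

def compute_relative_pose(R1, t1, R2, t2):
    # same composition as benchmarks/utils.py: the 0 -> 1 transform
    # given two world-to-camera poses, as used for T0_to_1 above
    R = R2 @ R1.T
    t = -R @ t1 + t2
    return R, t

# sanity check: a camera relative to itself is the identity pose
R1, t1 = np.eye(3), np.array([1.0, 2.0, 3.0])
R, t = compute_relative_pose(R1, t1, R1, t1)
assert np.allclose(R, np.eye(3)) and np.allclose(t, 0.0)
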
imcui/third_party/rdd/benchmarks/utils.py ADDED
@@ -0,0 +1,112 @@
1
+ import numpy as np
2
+ import torch
3
+ from kornia.geometry.epipolar import numeric
4
+ from kornia.geometry.conversions import convert_points_to_homogeneous
5
+ import cv2
+ import bisect  # used by dynamic_alpha below
6
+
7
+ def pose_auc(errors, thresholds):
8
+ sort_idx = np.argsort(errors)
9
+ errors = np.array(errors.copy())[sort_idx]
10
+ recall = (np.arange(len(errors)) + 1) / len(errors)
11
+ errors = np.r_[0.0, errors]
12
+ recall = np.r_[0.0, recall]
13
+ aucs = []
14
+ for t in thresholds:
15
+ last_index = np.searchsorted(errors, t)
16
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
17
+ e = np.r_[errors[:last_index], t]
18
+ aucs.append(np.trapz(r, x=e) / t)
19
+ return aucs
20
+
21
+ def angle_error_vec(v1, v2):
22
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
23
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
24
+
25
+ def angle_error_mat(R1, R2):
26
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
27
+ cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
28
+ return np.rad2deg(np.abs(np.arccos(cos)))
29
+
30
+ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
31
+ """Squared symmetric epipolar distance.
32
+ This can be seen as a biased estimation of the reprojection error.
33
+ Args:
34
+ pts0 (torch.Tensor): [N, 2]
35
+ E (torch.Tensor): [3, 3]
36
+ """
37
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
38
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
39
+ pts0 = convert_points_to_homogeneous(pts0)
40
+ pts1 = convert_points_to_homogeneous(pts1)
41
+
42
+ Ep0 = pts0 @ E.T # [N, 3]
43
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
44
+ Etp1 = pts1 @ E # [N, 3]
45
+
46
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
47
+ return d
48
+
49
+ def compute_symmetrical_epipolar_errors(T_0to1, pts0, pts1, K0, K1, device='cuda'):
50
+ """
51
+ Update:
52
+ data (dict):{"epi_errs": [M]}
53
+ """
54
+ pts0 = torch.tensor(pts0, device=device)
55
+ pts1 = torch.tensor(pts1, device=device)
56
+ K0 = torch.tensor(K0, device=device)
57
+ K1 = torch.tensor(K1, device=device)
58
+ T_0to1 = torch.tensor(T_0to1, device=device)
59
+ Tx = numeric.cross_product_matrix(T_0to1[:3, 3])
60
+ E_mat = Tx @ T_0to1[:3, :3]
61
+
62
+ epi_err = symmetric_epipolar_distance(pts0, pts1, E_mat, K0, K1)
63
+ return epi_err
64
+
65
+ def compute_pose_error(T_0to1, R, t):
66
+ R_gt = T_0to1[:3, :3]
67
+ t_gt = T_0to1[:3, 3]
68
+ error_t = angle_error_vec(t.squeeze(), t_gt)
69
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
70
+ error_R = angle_error_mat(R, R_gt)
71
+ return error_t, error_R
72
+
73
+ def compute_relative_pose(R1, t1, R2, t2):
74
+ rots = R2 @ (R1.T)
75
+ trans = -rots @ t1 + t2
76
+ return rots, trans
77
+
78
+ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
79
+ if len(kpts0) < 5:
80
+ return None
81
+ K0inv = np.linalg.inv(K0[:2,:2])
82
+ K1inv = np.linalg.inv(K1[:2,:2])
83
+
84
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
85
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
86
+ E, mask = cv2.findEssentialMat(
87
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
88
+ )
89
+
90
+ ret = None
91
+ if E is not None:
92
+ best_num_inliers = 0
93
+
94
+ for _E in np.split(E, len(E) / 3):
95
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
96
+ if n > best_num_inliers:
97
+ best_num_inliers = n
98
+ ret = (R, t, mask.ravel() > 0)
99
+ return ret
100
+
101
+ def dynamic_alpha(n_matches,
102
+ milestones=[0, 300, 1000, 2000],
103
+ alphas=[1.0, 0.8, 0.4, 0.2]):
104
+ if n_matches == 0:
105
+ return 1.0
106
+ ranges = list(zip(alphas, alphas[1:] + [None]))
107
+ loc = bisect.bisect_right(milestones, n_matches) - 1
108
+ _range = ranges[loc]
109
+ if _range[1] is None:
110
+ return _range[0]
111
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
112
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
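
pose_auc above integrates the recall-vs-error curve up to each threshold and normalizes by that threshold; a self-contained sketch on synthetic errors (same logic as the function above, illustrative numbers):

import numpy as np

def pose_auc(errors, thresholds):
    # identical logic to benchmarks/utils.py
    sort_idx = np.argsort(errors)
    errors = np.array(errors.copy())[sort_idx]
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(np.trapz(r, x=e) / t)
    return aucs

errs = np.array([1.0, 3.0, 7.0, 15.0, 40.0])  # illustrative pose errors in degrees
print(pose_auc(errs, [5, 10, 20]))            # ~ [0.30, 0.45, 0.615]
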
imcui/third_party/rdd/configs/default.yaml ADDED
@@ -0,0 +1,19 @@
1
+ activation: relu
2
+ block_dims:
3
+ - 8
4
+ - 16
5
+ - 32
6
+ - 64
7
+ d_model: 256
8
+ detection_threshold: 0.1
9
+ device: cuda
10
+ dim_feedforward: 1024
11
+ dropout: 0.1
12
+ enc_n_points: 8
13
+ hidden_dim: 256
14
+ lr_backbone: 2.0e-05
15
+ nhead: 8
16
+ num_encoder_layers: 4
17
+ num_feature_levels: 5
18
+ top_k: 4096
19
+ train_detector: False
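
The default config is plain YAML; a sketch of loading it (this diff does not show how build() consumes these keys, so passing the dict to build is an assumption):

import yaml

with open("imcui/third_party/rdd/configs/default.yaml") as f:
    config = yaml.safe_load(f)

print(config["top_k"], config["detection_threshold"])   # 4096 0.1
# model = build(weights="./weights/RDD-v2.pth", config=config)  # 'config' kwarg is assumed, not shown in this diff
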