diff --git a/.gitignore b/.gitignore index 89ff17b05947b5e88393988962e419b47f6fbbe0..d648d23b215c1c1a5bcb37be87f9ca1697bdd520 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ cmake-build-debug/ *.pyc flagged .ipynb_checkpoints -__pycache__ +**__pycache__** Untitled* experiments third_party/REKD diff --git a/README.md b/README.md index c38ecd53fb3ef0fd61dad9cf08aa5303079c8665..28d96db59a06b898223de6a30182b69458533a9b 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,9 @@ The tool currently supports various popular image matching algorithms, namely: | Algorithm | Supported | Conference/Journal | Year | GitHub Link | |------------------|-----------|--------------------|------|-------------| -| DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) | | LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) | +| RDD | ✅ | CVPR | 2025 | [Link](https://github.com/xtcpete/rdd) | +| DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) | | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) | | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) | | EfficientLoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) | diff --git a/config/config.yaml b/config/config.yaml index 035640b8c2aaf0d720b3bbbb71b28816d529523a..3f8ccc9d1d78920534bb270a162c2b8553100ab9 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -267,6 +267,27 @@ matcher_zoo: paper: https://arxiv.org/abs/2505.0342 project: null display: true + rdd(sparse): + matcher: NN-mutual + feature: rdd + dense: false + info: + name: RDD(sparse) #dispaly name + source: "CVPR 2025" + github: hhttps://github.com/xtcpete/rdd + paper: https://arxiv.org/abs/2505.08013 + project: https://xtcpete.github.io/rdd + display: true + rdd(dense): + matcher: rdd_dense + dense: true + info: + name: RDD(dense) #dispaly name + source: "CVPR 2025" + github: hhttps://github.com/xtcpete/rdd + paper: https://arxiv.org/abs/2505.08013 + project: https://xtcpete.github.io/rdd + display: true dedode: matcher: Dual-Softmax feature: dedode diff --git a/imcui/hloc/extract_features.py b/imcui/hloc/extract_features.py index 811dc73ea0c4aa6d2629d082d974734c34ab94d5..199584749281cdcd41be0f121f414146bf2187a5 100644 --- a/imcui/hloc/extract_features.py +++ b/imcui/hloc/extract_features.py @@ -225,6 +225,17 @@ confs = { "resize_max": 1600, }, }, + "rdd": { + "output": "feats-rdd-n5000-r1600", + "model": { + "name": "rdd", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, "aliked-n16-rot": { "output": "feats-aliked-n16-rot", "model": { diff --git a/imcui/hloc/extractors/liftfeat.py b/imcui/hloc/extractors/liftfeat.py index 6bc7f62fdb63dde9b74bd6be05843d8b5c977f20..860fa3a4608ed3cb46461b44b9f3ec83d6812747 100644 --- a/imcui/hloc/extractors/liftfeat.py +++ b/imcui/hloc/extractors/liftfeat.py @@ -1,13 +1,10 @@ -import logging import sys from pathlib import Path -import torch -import random from ..utils.base_model import BaseModel from .. import logger, MODEL_REPO_ID -fire_path = Path(__file__).parent / "../../third_party/LiftFeat" -sys.path.append(str(fire_path)) +liftfeat_path = Path(__file__).parent / "../../third_party/LiftFeat" +sys.path.append(str(liftfeat_path)) from models.liftfeat_wrapper import LiftFeat @@ -25,9 +22,7 @@ class Liftfeat(BaseModel): logger.info("Loading LiftFeat model...") model_path = self._download_model( repo_id=MODEL_REPO_ID, - filename="{}/{}".format( - Path(__file__).stem, self.conf["model_name"] - ), + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), ) self.net = LiftFeat( weight=model_path, diff --git a/imcui/hloc/extractors/rdd.py b/imcui/hloc/extractors/rdd.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a5df4848fb0e785c78ea2019cd00d5c5746b7 --- /dev/null +++ b/imcui/hloc/extractors/rdd.py @@ -0,0 +1,56 @@ +import sys +import yaml +from pathlib import Path +from ..utils.base_model import BaseModel +from .. import logger, MODEL_REPO_ID, DEVICE + +rdd_path = Path(__file__).parent / "../../third_party/rdd" +sys.path.append(str(rdd_path)) + +from RDD.RDD import build as build_rdd + +class Rdd(BaseModel): + default_conf = { + "keypoint_threshold": 0.1, + "max_keypoints": 4096, + "model_name": "RDD-v2.pth", + } + + required_inputs = ["image"] + + def _init(self, conf): + logger.info("Loading RDD model...") + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + Path(__file__).stem, self.conf["model_name"] + ), + ) + config_path = rdd_path / "configs/default.yaml" + with open(config_path, "r") as file: + config = yaml.safe_load(file) + config["top_k"] = conf["max_keypoints"] + config["detection_threshold"] = conf["keypoint_threshold"] + config["device"] = DEVICE + self.net = build_rdd(config=config, weights=model_path) + self.net.eval() + logger.info("Loading RDD model done!") + + def _forward(self, data): + image = data["image"] + pred = self.net.extract(image)[0] + keypoints = pred["keypoints"] + descriptors = pred["descriptors"] + scores = pred["scores"] + if self.conf["max_keypoints"] < len(keypoints): + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + pred = { + "keypoints": keypoints[None], + "descriptors": descriptors[None].permute(0, 2, 1), + "scores": scores[None], + } + return pred diff --git a/imcui/hloc/match_dense.py b/imcui/hloc/match_dense.py index c03ace218d4fd79b200a1dcfed99bddb83b00892..90063cf2db5f70d6d0b45e857c1f068cad9b3696 100644 --- a/imcui/hloc/match_dense.py +++ b/imcui/hloc/match_dense.py @@ -337,6 +337,23 @@ confs = { "dfactor": 8, }, }, + "rdd_dense": { + "output": "matches-rdd_dense", + "model": { + "name": "rdd_dense", + "model_name": "RDD-v2.pth", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 320, + "height": 240, + "dfactor": 8, + }, + }, "minima_roma": { "output": "matches-minima_roma", "model": { diff --git a/imcui/hloc/matchers/rdd_dense.py b/imcui/hloc/matchers/rdd_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd7a268b68f2ba74c746b859f9bf11c2c21b373 --- /dev/null +++ b/imcui/hloc/matchers/rdd_dense.py @@ -0,0 +1,52 @@ +import sys +import yaml +import torch +from pathlib import Path +from ..utils.base_model import BaseModel +from .. import logger, MODEL_REPO_ID, DEVICE + +rdd_path = Path(__file__).parent / "../../third_party/rdd" +sys.path.append(str(rdd_path)) + +from RDD.RDD import build as build_rdd +from RDD.RDD_helper import RDD_helper + +class RddDense(BaseModel): + default_conf = { + "keypoint_threshold": 0.1, + "max_keypoints": 4096, + "model_name": "RDD-v2.pth", + "match_threshold": 0.1, + } + + required_inputs = ["image0", "image1"] + + def _init(self, conf): + logger.info("Loading RDD model...") + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format( + "rdd", self.conf["model_name"] + ), + ) + config_path = rdd_path / "configs/default.yaml" + with open(config_path, "r") as file: + config = yaml.safe_load(file) + config["top_k"] = conf["max_keypoints"] + config["detection_threshold"] = conf["keypoint_threshold"] + config["device"] = DEVICE + rdd_net = build_rdd(config=config, weights=model_path) + rdd_net.eval() + self.net = RDD_helper(rdd_net) + logger.info("Loading RDD model done!") + + def _forward(self, data): + img0 = data["image0"] + img1 = data["image1"] + mkpts_0, mkpts_1, conf = self.net.match_dense(img0, img1, thr=self.conf["match_threshold"]) + pred = { + "keypoints0": torch.from_numpy(mkpts_0), + "keypoints1": torch.from_numpy(mkpts_1), + "mconf": torch.from_numpy(conf), + } + return pred diff --git a/imcui/third_party/rdd/.gitignore b/imcui/third_party/rdd/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..494a2893ab1422a94de82e7d7075434e77bf9ca7 --- /dev/null +++ b/imcui/third_party/rdd/.gitignore @@ -0,0 +1,8 @@ +.venv +/build/ +**.egg-info +**.pyc +/.idea/ +**/__pycache__/ +weights/ +outputs \ No newline at end of file diff --git a/imcui/third_party/rdd/LICENSE b/imcui/third_party/rdd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/imcui/third_party/rdd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/RDD.py b/imcui/third_party/rdd/RDD/RDD.py new file mode 100644 index 0000000000000000000000000000000000000000..411b54f884a7a0731aec6bfb6807d4fa847df12e --- /dev/null +++ b/imcui/third_party/rdd/RDD/RDD.py @@ -0,0 +1,260 @@ +# Description: RDD model +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np +from .utils import NestedTensor, nested_tensor_from_tensor_list, to_pixel_coords, read_config +from .models.detector import build_detector +from .models.descriptor import build_descriptor +from .models.soft_detect import SoftDetect +from .models.interpolator import InterpolateSparse2d + +class RDD(nn.Module): + + def __init__(self, detector, descriptor, detection_threshold=0.5, top_k=4096, train_detector=False, device='cuda'): + super().__init__() + self.detector = detector + self.descriptor = descriptor + self.interpolator = InterpolateSparse2d('bicubic') + self.detection_threshold = detection_threshold + self.top_k = top_k + self.device = device + if train_detector: + for p in self.detector.parameters(): + p.requires_grad = True + for p in self.descriptor.parameters(): + p.requires_grad = False + else: + for p in self.detector.parameters(): + p.requires_grad = False + for p in self.descriptor.parameters(): + p.requires_grad = True + + self.softdetect = None + self.stride = descriptor.stride + + def train(self, mode=True): + super().train(mode) + self.set_softdetect(top_k=500, scores_th=0.2) + + def eval(self): + super().eval() + self.set_softdetect(top_k=self.top_k, scores_th=0.01) + + def forward(self, samples: NestedTensor): + + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + + scoremap = self.detector(samples) + + feats, matchibility = self.descriptor(samples) + + return feats, scoremap, matchibility + + def set_softdetect(self, top_k=4096, scores_th=0.01): + self.softdetect = SoftDetect(radius=2, top_k=top_k, scores_th=scores_th) + + @torch.inference_mode() + def filter(self, matchibility): + # Filter out keypoints on the border + B, _, H, W = matchibility.shape + frame = torch.zeros(B, H, W, device=matchibility.device) + frame[:, self.stride:-self.stride, self.stride:-self.stride] = 1 + matchibility = matchibility * frame + return matchibility + + @torch.inference_mode() + def extract(self, x): + if self.softdetect is None: + self.eval() + + x, rh1, rw1 = self.preprocess_tensor(x) + x = x.to(self.device).float() + B, _, _H1, _W1 = x.shape + M1, K1, H1 = self.forward(x) + M1 = F.normalize(M1, dim=1) + + keypoints, kptscores, scoredispersitys = self.softdetect(K1) + + keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)]) + kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)]) + + keypoints = to_pixel_coords(keypoints, _H1, _W1) + + feats = self.interpolator(M1, keypoints, H = _H1, W = _W1) + + feats = F.normalize(feats, dim=-1) + + # Correct kpt scale + keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1) + valid = kptscores > self.detection_threshold + + return [ + {'keypoints': keypoints[b][valid[b]], + 'scores': kptscores[b][valid[b]], + 'descriptors': feats[b][valid[b]]} for b in range(B) + ] + + @torch.inference_mode() + def extract_3rd_party(self, x, model='aliked'): + """ + one image per batch + """ + x, rh1, rw1 = self.preprocess_tensor(x) + B, _, _H1, _W1 = x.shape + if model == 'aliked': + from third_party import extract_aliked_kpts + img = x + mkpts, scores = extract_aliked_kpts(img, self.device) + else: + raise ValueError('Unknown model') + + M1, _ = self.descriptor(x) + M1 = F.normalize(M1, dim=1) + + if mkpts.shape[1] > self.top_k: + idx = torch.argsort(scores, descending=True)[0][:self.top_k] + mkpts = mkpts[:,idx] + scores = scores[:,idx] + + feats = self.interpolator(M1, mkpts, H = _H1, W = _W1) + feats = F.normalize(feats, dim=-1) + mkpts = mkpts * torch.tensor([rw1,rh1], device=mkpts.device).view(1, 1, -1) + + return [ + {'keypoints': mkpts[b], + 'scores': scores[b], + 'descriptors': feats[b]} for b in range(B) + ] + + @torch.inference_mode() + def extract_dense(self, x, n_limit=30000, thr=0.01): + self.set_softdetect(top_k=n_limit, scores_th=-1) + + x, rh1, rw1 = self.preprocess_tensor(x) + + B, _, _H1, _W1 = x.shape + + M1, K1, H1 = self.forward(x) + M1 = F.normalize(M1, dim=1) + + keypoints, kptscores, scoredispersitys = self.softdetect(K1) + + keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)]) + kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)]) + + keypoints = to_pixel_coords(keypoints, _H1, _W1) + + feats = self.interpolator(M1, keypoints, H = _H1, W = _W1) + + feats = F.normalize(feats, dim=-1) + + H1 = self.filter(H1) + + dense_kpts, dense_scores, inds = self.sample_dense_kpts(H1, n_limit=n_limit) + + dense_keypoints = to_pixel_coords(dense_kpts, _H1, _W1) + + dense_feats = self.interpolator(M1, dense_keypoints, H = _H1, W = _W1) + + dense_feats = F.normalize(dense_feats, dim=-1) + + keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1) + dense_keypoints = dense_keypoints * torch.tensor([rw1,rh1], device=dense_keypoints.device).view(1, -1) + + valid = kptscores > self.detection_threshold + valid_dense = dense_scores > thr + + return [ + {'keypoints': keypoints[b][valid[b]], + 'scores': kptscores[b][valid[b]], + 'descriptors': feats[b][valid[b]], + 'keypoints_dense': dense_keypoints[b][valid_dense[b]], + 'scores_dense': dense_scores[b][valid_dense[b]], + 'descriptors_dense': dense_feats[b][valid_dense[b]]} for b in range(B) + ] + + @torch.inference_mode() + def sample_dense_kpts(self, keypoint_logits, threshold=0.01, n_limit=30000, force_kpts = True): + + B, K, H, W = keypoint_logits.shape + + if n_limit < 0 or n_limit > H*W: + n_limit = min(H*W - 1, n_limit) + + scoremap = keypoint_logits.permute(0,2,3,1) + + scoremap = scoremap.reshape(B, H, W) + + frame = torch.zeros(B, H, W, device=keypoint_logits.device) + + frame[:, 1:-1, 1:-1] = 1 + + scoremap = scoremap * frame + + scoremap = scoremap.reshape(B, H*W) + + grid = self.get_grid(B, H, W, device = keypoint_logits.device) + + inds = torch.topk(scoremap, n_limit, dim=1).indices + + # inds = torch.multinomial(scoremap, top_k, replacement=False) + kpts = torch.gather(grid, 1, inds[..., None].expand(B, n_limit, 2)) + scoremap = torch.gather(scoremap, 1, inds) + if force_kpts: + valid = scoremap > threshold + kpts = kpts[valid][None] + scoremap = scoremap[valid][None] + + return kpts, scoremap, inds + + def preprocess_tensor(self, x): + """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """ + if isinstance(x, np.ndarray) and len(x.shape) == 3: + x = torch.tensor(x).permute(2,0,1)[None] + x = x.to(self.device).float() + + H, W = x.shape[-2:] + + _H, _W = (H//32) * 32, (W//32) * 32 + + rh, rw = H/_H, W/_W + + x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False) + return x, rh, rw + + @torch.inference_mode() + def get_grid(self, B, H, W, device = None): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=device + ) + for n in (B, H, W) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + return x1_n + +def build(config=None, weights=None): + if config is None: + config = read_config('./configs/default.yaml') + if weights is not None: + config['weights'] = weights + device = torch.device(config['device']) + print('config', config) + detector = build_detector(config) + descriptor = build_descriptor(config) + model = RDD( + detector, + descriptor, + detection_threshold=config['detection_threshold'], + top_k=config['top_k'], + train_detector=config['train_detector'], + device=device + ) + if 'weights' in config and config['weights'] is not None: + model.load_state_dict(torch.load(config['weights'], map_location='cpu')) + model.to(device) + return model \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/RDD_helper.py b/imcui/third_party/rdd/RDD/RDD_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7561ea3650e8fc641d23b4b43bf20209fefc5c91 --- /dev/null +++ b/imcui/third_party/rdd/RDD/RDD_helper.py @@ -0,0 +1,179 @@ +from .matchers import DualSoftmaxMatcher, DenseMatcher, LightGlue +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import kornia + +class RDD_helper(nn.Module): + def __init__(self, RDD): + super().__init__() + self.matcher = DualSoftmaxMatcher(inv_temperature = 20, thr = 0.01) + self.dense_matcher = DenseMatcher(inv_temperature=20, thr=0.01) + self.RDD = RDD + self.lg_matcher = None + + @torch.inference_mode() + def match(self, img0, img1, thr=0.01, resize=None, top_k=4096): + if top_k is not None and top_k != self.RDD.top_k: + self.RDD.top_k = top_k + self.RDD.set_softdetect(top_k=top_k) + + img0, scale0 = self.parse_input(img0, resize) + img1, scale1 = self.parse_input(img1, resize) + + out0 = self.RDD.extract(img0)[0] + out1 = self.RDD.extract(img1)[0] + + # get top_k confident matches + mkpts0, mkpts1, conf = self.matcher(out0, out1, thr) + + scale0 = 1.0 / scale0 + scale1 = 1.0 / scale1 + + mkpts0 = mkpts0 * scale0 + mkpts1 = mkpts1 * scale1 + + return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy() + + @torch.inference_mode() + def match_lg(self, img0, img1, thr=0.01, resize=None, top_k=4096): + if self.lg_matcher is None: + lg_conf = { + "name": "lightglue", # just for interfacing + "input_dim": 256, # input descriptor dimension (autoselected from weights) + "descriptor_dim": 256, + "add_scale_ori": False, + "n_layers": 9, + "num_heads": 4, + "flash": True, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "filter_threshold": 0.01, # match threshold + "depth_confidence": -1, # depth confidence threshold + "width_confidence": -1, # width confidence threshold + "weights": './weights/RDD_lg-v2.pth', # path to the weights + } + self.lg_matcher = LightGlue(features='rdd', conf=lg_conf).to(self.RDD.device) + + if top_k is not None and top_k != self.RDD.top_k: + self.RDD.top_k = top_k + self.RDD.set_softdetect(top_k=top_k) + + img0, scale0 = self.parse_input(img0, resize=resize) + img1, scale1 = self.parse_input(img1, resize=resize) + + size0 = torch.tensor(img0.shape[-2:])[None] + size1 = torch.tensor(img1.shape[-2:])[None] + + out0 = self.RDD.extract(img0)[0] + out1 = self.RDD.extract(img1)[0] + + # get top_k confident matches + image0_data = { + 'keypoints': out0['keypoints'][None], + 'descriptors': out0['descriptors'][None], + 'image_size': size0, + } + + image1_data = { + 'keypoints': out1['keypoints'][None], + 'descriptors': out1['descriptors'][None], + 'image_size': size1, + } + + pred = {} + + with torch.no_grad(): + pred.update({'image0': image0_data, 'image1': image1_data}) + pred.update(self.lg_matcher({**pred})) + + kpts0 = pred['image0']['keypoints'][0] + kpts1 = pred['image1']['keypoints'][0] + + matches = pred['matches'][0] + + mkpts0 = kpts0[matches[... , 0]] + mkpts1 = kpts1[matches[... , 1]] + conf = pred['scores'][0] + + valid_mask = conf > thr + mkpts0 = mkpts0[valid_mask] + mkpts1 = mkpts1[valid_mask] + conf = conf[valid_mask] + + scale0 = 1.0 / scale0 + scale1 = 1.0 / scale1 + mkpts0 = mkpts0 * scale0 + mkpts1 = mkpts1 * scale1 + + return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy() + + @torch.inference_mode() + def match_dense(self, img0, img1, thr=0.01, resize=None): + + img0, scale0 = self.parse_input(img0, resize=resize) + img1, scale1 = self.parse_input(img1, resize=resize) + + out0 = self.RDD.extract_dense(img0)[0] + out1 = self.RDD.extract_dense(img1)[0] + + # get top_k confident matches + mkpts0, mkpts1, conf = self.dense_matcher(out0, out1, thr, err_thr=self.RDD.stride) + + scale0 = 1.0 / scale0 + scale1 = 1.0 / scale1 + + mkpts0 = mkpts0 * scale0 + mkpts1 = mkpts1 * scale1 + + return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy() + + @torch.inference_mode() + def match_3rd_party(self, img0, img1, model='aliked', resize=None, thr=0.01): + img0, scale0 = self.parse_input(img0, resize=resize) + img1, scale1 = self.parse_input(img1, resize=resize) + + out0 = self.RDD.extract_3rd_party(img0, model=model)[0] + out1 = self.RDD.extract_3rd_party(img1, model=model)[0] + + mkpts0, mkpts1, conf = self.matcher(out0, out1, thr) + + scale0 = 1.0 / scale0 + scale1 = 1.0 / scale1 + + mkpts0 = mkpts0 * scale0 + mkpts1 = mkpts1 * scale1 + + return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy() + + def parse_input(self, x, resize=None): + if len(x.shape) == 3: + x = x[None, ...] + + if isinstance(x, np.ndarray): + x = torch.tensor(x).permute(0,3,1,2)/255 + + h, w = x.shape[-2:] + size = h, w + + if resize is not None: + size = self.get_new_image_size(h, w, resize) + x = kornia.geometry.transform.resize( + x, + size, + side='long', + antialias=True, + align_corners=None, + interpolation='bilinear', + ) + scale = torch.Tensor([x.shape[-1] / w, x.shape[-2] / h]).to(self.RDD.device) + + return x, scale + + def get_new_image_size(self, h, w, resize=1600): + aspect_ratio = w / h + size = int(resize / aspect_ratio), resize + + size = list(map(lambda x: int(x // 32 * 32), size)) # make sure size is divisible by 32 + + return size \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/dataset/__init__.py b/imcui/third_party/rdd/RDD/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py b/imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c2bfaf13f81e0c5ad896ca3e3f9292a2e5c7f5 --- /dev/null +++ b/imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py @@ -0,0 +1,2 @@ +from .megadepth import * +from .megadepth_warper import * diff --git a/imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py b/imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..92be0b36ac7049d02eb3f8d6ee6e7f51a9ff3bf9 --- /dev/null +++ b/imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py @@ -0,0 +1,313 @@ +import os +import copy +import h5py +import torch +import pickle +import numpy as np +from PIL import Image +from tqdm import tqdm +from pathlib import Path +from torchvision import transforms +from torch.utils.data import Dataset + +import cv2 +from .utils import scale_intrinsics, warp_depth, warp_points2d + +class MegaDepthDataset(Dataset): + def __init__( + self, + root, + npz_path, + num_per_scene=100, + image_size=256, + min_overlap_score=0.1, + max_overlap_score=0.9, + gray=False, + crop_or_scale='scale', # crop, scale, crop_scale + train=True, + ): + self.data_path = Path(root) + self.num_per_scene = num_per_scene + self.train = train + self.image_size = image_size + self.gray = gray + self.crop_or_scale = crop_or_scale + + self.scene_info = dict(np.load(npz_path, allow_pickle=True)) + self.pair_infos = self.scene_info['pair_infos'].copy() + del self.scene_info['pair_infos'] + self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score] + if len(self.pair_infos) > num_per_scene: + indices = np.random.choice(len(self.pair_infos), num_per_scene, replace=False) + self.pair_infos = [self.pair_infos[idx] for idx in indices] + self.transforms = transforms.Compose([transforms.ToPILImage(), + transforms.ToTensor()]) + + def __len__(self): + return len(self.pair_infos) + + def recover_pair(self, idx): + (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)] + + img_name1 = self.scene_info['image_paths'][idx0] + img_name2 = self.scene_info['image_paths'][idx1] + + depth1 = '/'.join([self.scene_info['depth_paths'][idx0].replace('phoenix/S6/zl548/MegaDepth_v1', 'depth_undistorted').split('/')[i] for i in [0, 1, -1]]) + depth2 = '/'.join([self.scene_info['depth_paths'][idx1].replace('phoenix/S6/zl548/MegaDepth_v1', 'depth_undistorted').split('/')[i] for i in [0, 1, -1]]) + + depth_path1 = self.data_path / depth1 + with h5py.File(depth_path1, 'r') as hdf5_file: + depth1 = np.array(hdf5_file['/depth']) + assert (np.min(depth1) >= 0) + image_path1 = self.data_path / img_name1 + image1 = Image.open(image_path1) + if image1.mode != 'RGB': + image1 = image1.convert('RGB') + image1 = np.array(image1) + assert (image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1]) + intrinsics1 = self.scene_info['intrinsics'][idx0].copy() + pose1 = self.scene_info['poses'][idx0] + + depth_path2 = self.data_path / depth2 + with h5py.File(depth_path2, 'r') as hdf5_file: + depth2 = np.array(hdf5_file['/depth']) + assert (np.min(depth2) >= 0) + image_path2 = self.data_path / img_name2 + image2 = Image.open(image_path2) + if image2.mode != 'RGB': + image2 = image2.convert('RGB') + image2 = np.array(image2) + assert (image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1]) + intrinsics2 = self.scene_info['intrinsics'][idx1].copy() + pose2 = self.scene_info['poses'][idx1] + + pose12 = pose2 @ np.linalg.inv(pose1) + pose21 = np.linalg.inv(pose12) + + if self.train: + if "crop" in self.crop_or_scale: + # ================================================= compute central_match + DOWNSAMPLE = 10 + # resize to speed up + depth1s = cv2.resize(depth1, (depth1.shape[1] // DOWNSAMPLE, depth1.shape[0] // DOWNSAMPLE)) + depth2s = cv2.resize(depth2, (depth2.shape[1] // DOWNSAMPLE, depth2.shape[0] // DOWNSAMPLE)) + intrinsic1s = scale_intrinsics(intrinsics1, (DOWNSAMPLE, DOWNSAMPLE)) + intrinsic2s = scale_intrinsics(intrinsics2, (DOWNSAMPLE, DOWNSAMPLE)) + + # warp + depth12s = warp_depth(depth1s, intrinsic1s, intrinsic2s, pose12, depth2s.shape) + depth21s = warp_depth(depth2s, intrinsic2s, intrinsic1s, pose21, depth1s.shape) + + depth12s[depth12s < 0] = 0 + depth21s[depth21s < 0] = 0 + + valid12s = np.logical_and(depth12s > 0, depth2s > 0) + valid21s = np.logical_and(depth21s > 0, depth1s > 0) + + pos1 = np.array(valid21s.nonzero()) + try: + idx1_random = np.random.choice(np.arange(pos1.shape[1]), 1) + uv1s = pos1[:, idx1_random][[1, 0]].reshape(1, 2) + d1s = np.array(depth1s[uv1s[0, 1], uv1s[0, 0]]).reshape(1, 1) + + uv12s, z12s = warp_points2d(uv1s, d1s, intrinsic1s, intrinsic2s, pose12) + + uv1 = uv1s[0] * DOWNSAMPLE + uv2 = uv12s[0] * DOWNSAMPLE + except ValueError: + uv1 = [depth1.shape[1] / 2, depth1.shape[0] / 2] + uv2 = [depth2.shape[1] / 2, depth2.shape[0] / 2] + + central_match = [uv1[1], uv1[0], uv2[1], uv2[0]] + # ================================================= compute central_match + + if self.crop_or_scale == 'crop': + # =============== padding + h1, w1, _ = image1.shape + h2, w2, _ = image2.shape + if h1 < self.image_size: + padding = np.zeros((self.image_size - h1, w1, 3)) + image1 = np.concatenate([image1, padding], axis=0).astype(np.uint8) + depth1 = np.concatenate([depth1, padding[:, :, 0]], axis=0).astype(np.float32) + h1, w1, _ = image1.shape + if w1 < self.image_size: + padding = np.zeros((h1, self.image_size - w1, 3)) + image1 = np.concatenate([image1, padding], axis=1).astype(np.uint8) + depth1 = np.concatenate([depth1, padding[:, :, 0]], axis=1).astype(np.float32) + if h2 < self.image_size: + padding = np.zeros((self.image_size - h2, w2, 3)) + image2 = np.concatenate([image2, padding], axis=0).astype(np.uint8) + depth2 = np.concatenate([depth2, padding[:, :, 0]], axis=0).astype(np.float32) + h2, w2, _ = image2.shape + if w2 < self.image_size: + padding = np.zeros((h2, self.image_size - w2, 3)) + image2 = np.concatenate([image2, padding], axis=1).astype(np.uint8) + depth2 = np.concatenate([depth2, padding[:, :, 0]], axis=1).astype(np.float32) + # =============== padding + image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match) + + depth1 = depth1[bbox1[0]: bbox1[0] + self.image_size, bbox1[1]: bbox1[1] + self.image_size] + depth2 = depth2[bbox2[0]: bbox2[0] + self.image_size, bbox2[1]: bbox2[1] + self.image_size] + elif self.crop_or_scale == 'scale': + image1, depth1, intrinsics1 = self.scale(image1, depth1, intrinsics1) + image2, depth2, intrinsics2 = self.scale(image2, depth2, intrinsics2) + bbox1 = bbox2 = np.array([0., 0.]) + elif self.crop_or_scale == 'crop_scale': + bbox1 = bbox2 = np.array([0., 0.]) + image1, depth1, intrinsics1 = self.crop_scale(image1, depth1, intrinsics1, central_match[:2]) + image2, depth2, intrinsics2 = self.crop_scale(image2, depth2, intrinsics2, central_match[2:]) + else: + raise RuntimeError(f"Unkown type {self.crop_or_scale}") + else: + bbox1 = bbox2 = np.array([0., 0.]) + + return (image1, depth1, intrinsics1, pose12, bbox1, + image2, depth2, intrinsics2, pose21, bbox2) + + def scale(self, image, depth, intrinsic): + img_size_org = image.shape + image = cv2.resize(image, (self.image_size, self.image_size)) + depth = cv2.resize(depth, (self.image_size, self.image_size)) + intrinsic = scale_intrinsics(intrinsic, (img_size_org[1] / self.image_size, img_size_org[0] / self.image_size)) + return image, depth, intrinsic + + def crop_scale(self, image, depth, intrinsic, centeral): + h_org, w_org, three = image.shape + image_size = min(h_org, w_org) + if h_org > w_org: + if centeral[1] - image_size // 2 < 0: + h_start = 0 + elif centeral[1] + image_size // 2 > h_org: + h_start = h_org - image_size + else: + h_start = int(centeral[1]) - image_size // 2 + w_start = 0 + else: + if centeral[0] - image_size // 2 < 0: + w_start = 0 + elif centeral[0] + image_size // 2 > w_org: + w_start = w_org - image_size + else: + w_start = int(centeral[0]) - image_size // 2 + h_start = 0 + + croped_image = image[h_start: h_start + image_size, w_start: w_start + image_size] + croped_depth = depth[h_start: h_start + image_size, w_start: w_start + image_size] + intrinsic[0, 2] = intrinsic[0, 2] - w_start + intrinsic[1, 2] = intrinsic[1, 2] - h_start + + image = cv2.resize(croped_image, (self.image_size, self.image_size)) + depth = cv2.resize(croped_depth, (self.image_size, self.image_size)) + intrinsic = scale_intrinsics(intrinsic, (image_size / self.image_size, image_size / self.image_size)) + + return image, depth, intrinsic + + def crop(self, image1, image2, central_match): + bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0) + if bbox1_i + self.image_size >= image1.shape[0]: + bbox1_i = image1.shape[0] - self.image_size + bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0) + if bbox1_j + self.image_size >= image1.shape[1]: + bbox1_j = image1.shape[1] - self.image_size + + bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0) + if bbox2_i + self.image_size >= image2.shape[0]: + bbox2_i = image2.shape[0] - self.image_size + bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0) + if bbox2_j + self.image_size >= image2.shape[1]: + bbox2_j = image2.shape[1] - self.image_size + + return (image1[bbox1_i: bbox1_i + self.image_size, bbox1_j: bbox1_j + self.image_size], + np.array([bbox1_i, bbox1_j]), + image2[bbox2_i: bbox2_i + self.image_size, bbox2_j: bbox2_j + self.image_size], + np.array([bbox2_i, bbox2_j]) + ) + + def __getitem__(self, idx): + (image1, depth1, intrinsics1, pose12, bbox1, + image2, depth2, intrinsics2, pose21, bbox2) \ + = self.recover_pair(idx) + + if self.gray: + gray1 = cv2.cvtColor(image1, cv2.COLOR_RGB2GRAY) + gray2 = cv2.cvtColor(image2, cv2.COLOR_RGB2GRAY) + gray1 = transforms.ToTensor()(gray1) + gray2 = transforms.ToTensor()(gray2) + if self.transforms is not None: + image1, image2 = self.transforms(image1), self.transforms(image2) # [C,H,W] + ret = {'image0': image1, + 'image1': image2, + 'angle': 0, + 'overlap': self.pair_infos[idx][1], + 'warp01_params': {'mode': 'se3', + 'width': self.image_size if self.train else image1.shape[2], + 'height': self.image_size if self.train else image1.shape[1], + 'pose01': torch.from_numpy(pose12.astype(np.float32)), + 'bbox0': torch.from_numpy(bbox1.astype(np.float32)), + 'bbox1': torch.from_numpy(bbox2.astype(np.float32)), + 'depth0': torch.from_numpy(depth1.astype(np.float32)), + 'depth1': torch.from_numpy(depth2.astype(np.float32)), + 'intrinsics0': torch.from_numpy(intrinsics1.astype(np.float32)), + 'intrinsics1': torch.from_numpy(intrinsics2.astype(np.float32))}, + 'warp10_params': {'mode': 'se3', + 'width': self.image_size if self.train else image2.shape[2], + 'height': self.image_size if self.train else image2.shape[2], + 'pose01': torch.from_numpy(pose21.astype(np.float32)), + 'bbox0': torch.from_numpy(bbox2.astype(np.float32)), + 'bbox1': torch.from_numpy(bbox1.astype(np.float32)), + 'depth0': torch.from_numpy(depth2.astype(np.float32)), + 'depth1': torch.from_numpy(depth1.astype(np.float32)), + 'intrinsics0': torch.from_numpy(intrinsics2.astype(np.float32)), + 'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32))}, + } + if self.gray: + ret['gray0'] = gray1 + ret['gray1'] = gray2 + return ret + + +if __name__ == '__main__': + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + + + def visualize(image0, image1, depth0, depth1): + # visualize image and depth + plt.figure(figsize=(9, 9)) + plt.subplot(2, 2, 1) + plt.imshow(image0, cmap='gray') + plt.subplot(2, 2, 2) + plt.imshow(depth0) + plt.subplot(2, 2, 3) + plt.imshow(image1, cmap='gray') + plt.subplot(2, 2, 4) + plt.imshow(depth1) + plt.show() + + + dataset = MegaDepthDataset( # root='../data/megadepth', + root='../data/imw2020val', + train=False, + using_cache=True, + pairs_per_scene=100, + image_size=256, + colorjit=True, + gray=False, + crop_or_scale='scale', + ) + dataset.build_dataset() + + batch_size = 2 + + loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0) + + for idx, batch in enumerate(tqdm(loader)): + image0, image1 = batch['image0'], batch['image1'] # [B,3,H,W] + depth0, depth1 = batch['warp01_params']['depth0'], batch['warp01_params']['depth1'] # [B,H,W] + intrinsics0, intrinsics1 = batch['warp01_params']['intrinsics0'], batch['warp01_params'][ + 'intrinsics1'] # [B,3,3] + + batch_size, channels, h, w = image0.shape + + for b_idx in range(batch_size): + visualize(image0[b_idx].permute(1, 2, 0), image1[b_idx].permute(1, 2, 0), depth0[b_idx], depth1[b_idx]) \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py b/imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py new file mode 100644 index 0000000000000000000000000000000000000000..d34218ad89cada36b8467379997e46da26c6e0ef --- /dev/null +++ b/imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py @@ -0,0 +1,75 @@ +import torch +from kornia.utils import create_meshgrid +import matplotlib.pyplot as plt +import pdb +from .utils import warp + +@torch.no_grad() +def spvs_coarse(data, scale = 8): + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + device = data['image0'].device + corrs = [] + for idx in range(N): + warp01_params = {} + for k, v in data['warp01_params'].items(): + if isinstance(v[idx], torch.Tensor): + warp01_params[k] = v[idx].to(device) + else: + warp01_params[k] = v[idx] + warp10_params = {} + for k, v in data['warp10_params'].items(): + if isinstance(v[idx], torch.Tensor): + warp10_params[k] = v[idx].to(device) + else: + warp10_params[k] = v[idx] + + # create kpts + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(h1*w1, 2) # [N, hw, 2] + + # normalize kpts + grid_pt1_c = grid_pt1_c * scale + + + + try: + grid_pt1_c_valid, grid_pt10_c, ids1, ids1_out = warp(grid_pt1_c, warp10_params) + grid_pt10_c_valid, grid_pt01_c, ids0, ids0_out = warp(grid_pt10_c, warp01_params) + + # check reproj error + grid_pt1_c_valid = grid_pt1_c_valid[ids0] + dist = torch.linalg.norm(grid_pt1_c_valid - grid_pt01_c, dim=-1) + + mask_mutual = (dist < 1.5) + + #get correspondences + pts = torch.cat([grid_pt10_c_valid[mask_mutual] / scale, + grid_pt01_c[mask_mutual] / scale], dim=-1) + #remove repeated correspondences + lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1 + lut_mat21 = torch.clone(lut_mat12) + src_pts = pts[:, :2] + tgt_pts = pts[:, 2:] + + lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) + mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1) + points = lut_mat12[mask_valid12] + + #Target-src check + src_pts, tgt_pts = points[:, :2], points[:, 2:] + lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) + mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1) + points = lut_mat21[mask_valid21] + + corrs.append(points) + except: + corrs.append(torch.zeros((0, 4), device = device)) + #pdb.set_trace() + #print('..') + + #Plot for debug purposes + # for i in range(len(corrs)): + # plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8) + + return corrs \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/dataset/megadepth/utils.py b/imcui/third_party/rdd/RDD/dataset/megadepth/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e91687b87bb27bc8216e1294668d9d547e25e7cd --- /dev/null +++ b/imcui/third_party/rdd/RDD/dataset/megadepth/utils.py @@ -0,0 +1,848 @@ +""" + "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." + https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/ + + MegaDepth data handling was adapted from + LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py +""" + +import io +import cv2 +import numpy as np +import h5py +import torch +from numpy.linalg import inv +from kornia.geometry.epipolar import essential_from_Rt +from kornia.geometry.epipolar import fundamental_from_essential + +import cv2 +import torch +import numpy as np +from numba import jit +from copy import deepcopy + +try: + from utils.project_depth_nn_cython_pkg import project_depth_nn_cython + + nn_cython = True +except: + print('\033[1;41;37mWarning: using python to project depth!!!\033[0m') + + nn_cython = False + + +class EmptyTensorError(Exception): + pass + + +def mutual_NN(cross_matrix, mode: str = 'min'): + """ + compute mutual nearest neighbor from a cross_matrix, non-differentiable function + :param cross_matrix: N0xN1 + :param mode: 'min': mutual minimum; 'max':mutual maximum + :return: index0,index1, Mx2 + """ + if mode == 'min': + nn0 = cross_matrix == cross_matrix.min(dim=1, keepdim=True)[0] + nn1 = cross_matrix == cross_matrix.min(dim=0, keepdim=True)[0] + elif mode == 'max': + nn0 = cross_matrix == cross_matrix.max(dim=1, keepdim=True)[0] + nn1 = cross_matrix == cross_matrix.max(dim=0, keepdim=True)[0] + else: + raise TypeError("error mode, must be 'min' or 'max'.") + + mutual_nn = nn0 * nn1 + + return torch.nonzero(mutual_nn, as_tuple=False) + + +def mutual_argmax(value, mask=None, as_tuple=True): + """ + Args: + value: MxN + mask: MxN + + Returns: + + """ + value = value - value.min() # convert to non-negative tensor + if mask is not None: + value = value * mask + + max0 = value.max(dim=1, keepdim=True) # the col index the max value in each row + max1 = value.max(dim=0, keepdim=True) + + valid_max0 = value == max0[0] + valid_max1 = value == max1[0] + + mutual = valid_max0 * valid_max1 + if mask is not None: + mutual = mutual * mask + + return mutual.nonzero(as_tuple=as_tuple) + + +def mutual_argmin(value, mask=None): + return mutual_argmax(-value, mask) + + +def compute_keypoints_distance(kpts0, kpts1, p=2): + """ + Args: + kpts0: torch.tensor [M,2] + kpts1: torch.tensor [N,2] + p: (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm + + Returns: + dist, torch.tensor [N,M] + """ + dist = kpts0[:, None, :] - kpts1[None, :, :] # [M,N,2] + dist = torch.norm(dist, p=p, dim=2) # [M,N] + return dist + + +def keypoints_normal2pixel(kpts_normal, w, h): + wh = kpts_normal[0].new_tensor([[w - 1, h - 1]]) + kpts_pixel = [(kpts + 1) / 2 * wh for kpts in kpts_normal] + return kpts_pixel + + +def plot_keypoints(image, kpts, radius=2, color=(255, 0, 0)): + image = image.cpu().detach().numpy() if isinstance(image, torch.Tensor) else image + kpts = kpts.cpu().detach().numpy() if isinstance(kpts, torch.Tensor) else kpts + + if image.dtype is not np.dtype('uint8'): + image = image * 255 + image = image.astype(np.uint8) + + if len(image.shape) == 2 or image.shape[2] == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + out = np.ascontiguousarray(deepcopy(image)) + kpts = np.round(kpts).astype(int) + + for kpt in kpts: + y0, x0 = kpt + cv2.drawMarker(out, (x0, y0), color, cv2.MARKER_CROSS, radius) + + # cv2.circle(out, (x0, y0), radius, color, -1, lineType=cv2.LINE_4) + return out + + +def plot_matches(image0, image1, kpts0, kpts1, radius=2, color=(255, 0, 0), mcolor=(0, 255, 0), layout='lr'): + image0 = image0.cpu().detach().numpy() if isinstance(image0, torch.Tensor) else image0 + image1 = image1.cpu().detach().numpy() if isinstance(image1, torch.Tensor) else image1 + kpts0 = kpts0.cpu().detach().numpy() if isinstance(kpts0, torch.Tensor) else kpts0 + kpts1 = kpts1.cpu().detach().numpy() if isinstance(kpts1, torch.Tensor) else kpts1 + + out0 = plot_keypoints(image0, kpts0, radius, color) + out1 = plot_keypoints(image1, kpts1, radius, color) + + H0, W0 = image0.shape[0], image0.shape[1] + H1, W1 = image1.shape[0], image1.shape[1] + + if layout == "lr": + H, W = max(H0, H1), W0 + W1 + out = 255 * np.ones((H, W, 3), np.uint8) + out[:H0, :W0, :] = out0 + out[:H1, W0:, :] = out1 + elif layout == "ud": + H, W = H0 + H1, max(W0, W1) + out = 255 * np.ones((H, W, 3), np.uint8) + out[:H0, :W0, :] = out0 + out[H0:, :W1, :] = out1 + else: + raise ValueError("The layout must be 'lr' or 'ud'!") + + kpts0 = np.round(kpts0).astype(int) + kpts1 = np.round(kpts1).astype(int) + + for kpt0, kpt1 in zip(kpts0, kpts1): + (y0, x0), (y1, x1) = kpt0, kpt1 + + if layout == "lr": + cv2.line(out, (x0, y0), (x1 + W0, y1), color=mcolor, thickness=1, lineType=cv2.LINE_AA) + elif layout == "ud": + cv2.line(out, (x0, y0), (x1, y1 + H0), color=mcolor, thickness=1, lineType=cv2.LINE_AA) + + return out + + +def interpolate_depth(pos, depth): + pos = pos.t()[[1, 0]] # Nx2 -> 2xN; w,h -> h,w(i,j) + + # =============================================== from d2-net + device = pos.device + + ids = torch.arange(0, pos.size(1), device=device) + + h, w = depth.size() + + i = pos[0, :].detach() # TODO: changed here + j = pos[1, :].detach() # TODO: changed here + + # Valid corners + i_top_left = torch.floor(i).long() + j_top_left = torch.floor(j).long() + valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0) + + i_top_right = torch.floor(i).long() + j_top_right = torch.ceil(j).long() + valid_top_right = torch.min(i_top_right >= 0, j_top_right < w) + + i_bottom_left = torch.ceil(i).long() + j_bottom_left = torch.floor(j).long() + valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0) + + i_bottom_right = torch.ceil(i).long() + j_bottom_right = torch.ceil(j).long() + valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w) + + valid_corners = torch.min(torch.min(valid_top_left, valid_top_right), + torch.min(valid_bottom_left, valid_bottom_right)) + + i_top_left = i_top_left[valid_corners] + j_top_left = j_top_left[valid_corners] + + i_top_right = i_top_right[valid_corners] + j_top_right = j_top_right[valid_corners] + + i_bottom_left = i_bottom_left[valid_corners] + j_bottom_left = j_bottom_left[valid_corners] + + i_bottom_right = i_bottom_right[valid_corners] + j_bottom_right = j_bottom_right[valid_corners] + + ids = ids[valid_corners] + ids_valid_corners = deepcopy(ids) + if ids.size(0) == 0: + # raise ValueError('empty tensor: ids') + raise EmptyTensorError + + # Valid depth + valid_depth = torch.min(torch.min(depth[i_top_left, j_top_left] > 0, + depth[i_top_right, j_top_right] > 0), + torch.min(depth[i_bottom_left, j_bottom_left] > 0, + depth[i_bottom_right, j_bottom_right] > 0)) + + i_top_left = i_top_left[valid_depth] + j_top_left = j_top_left[valid_depth] + + i_top_right = i_top_right[valid_depth] + j_top_right = j_top_right[valid_depth] + + i_bottom_left = i_bottom_left[valid_depth] + j_bottom_left = j_bottom_left[valid_depth] + + i_bottom_right = i_bottom_right[valid_depth] + j_bottom_right = j_bottom_right[valid_depth] + + ids = ids[valid_depth] + ids_valid_depth = deepcopy(ids) + if ids.size(0) == 0: + # raise ValueError('empty tensor: ids') + raise EmptyTensorError + + # Interpolation + i = i[ids] + j = j[ids] + dist_i_top_left = i - i_top_left.float() + dist_j_top_left = j - j_top_left.float() + w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) + w_top_right = (1 - dist_i_top_left) * dist_j_top_left + w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) + w_bottom_right = dist_i_top_left * dist_j_top_left + + interpolated_depth = (w_top_left * depth[i_top_left, j_top_left] + + w_top_right * depth[i_top_right, j_top_right] + + w_bottom_left * depth[i_bottom_left, j_bottom_left] + + w_bottom_right * depth[i_bottom_right, j_bottom_right]) + + # pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0) + + pos = pos[:, ids] + + # =============================================== from d2-net + pos = pos[[1, 0]].t() # 2xN -> Nx2; h,w(i,j) -> w,h + + # interpolated_depth: valid interpolated depth + # pos: valid position (keypoint) + # ids: indices of valid position (keypoint) + + return [interpolated_depth, pos, ids, ids_valid_corners, ids_valid_depth] + + +def to_homogeneous(kpts): + ''' + :param kpts: Nx2 + :return: Nx3 + ''' + ones = kpts.new_ones([kpts.shape[0], 1]) + return torch.cat((kpts, ones), dim=1) + + +def warp_homography(kpts0, params): + ''' + :param kpts: Nx2 + :param homography_matrix: 3x3 + :return: + ''' + homography_matrix = params['homography_matrix'] + w, h = params['width'], params['height'] + kpts0_homogeneous = to_homogeneous(kpts0) + kpts01_homogeneous = torch.einsum('ij,kj->ki', homography_matrix, kpts0_homogeneous) + kpts01 = kpts01_homogeneous[:, :2] / kpts01_homogeneous[:, 2:] + + kpts01_ = kpts01.detach() + # due to float coordinates, the upper boundary should be (w-1) and (h-1). + # For example, if the image size is 480, then the coordinates should in [0~470]. + # 470.5 is not acceptable. + valid01 = (kpts01_[:, 0] >= 0) * (kpts01_[:, 0] <= w - 1) * (kpts01_[:, 1] >= 0) * (kpts01_[:, 1] <= h - 1) + kpts0_valid = kpts0[valid01] + kpts01_valid = kpts01[valid01] + ids = torch.nonzero(valid01, as_tuple=False)[:, 0] + ids_out = torch.nonzero(~valid01, as_tuple=False)[:, 0] + + # kpts0_valid: valid keypoints0, the invalid and inconsistance keypoints are removed + # kpts01_valid: the warped valid keypoints0 + # ids: the valid indices + return kpts0_valid, kpts01_valid, ids, ids_out + + +def project(points3d, K): + """ + project 3D points to image plane + + Args: + points3d: [N,3] + K: [3,3] + + Returns: + uv, (u,v), [N,2] + """ + if type(K) == torch.Tensor: + zuv1 = torch.einsum('jk,nk->nj', K, points3d) # z*(u,v,1) = K*points3d -> [N,3] + elif type(K) == np.ndarray: + zuv1 = np.einsum('jk,nk->nj', K, points3d) + else: + raise TypeError("Input type should be 'torch.tensor' or 'numpy.ndarray'") + uv1 = zuv1 / zuv1[:, -1][:, None] # (u,v,1) -> [N,3] + uv = uv1[:, 0:2] # (u,v) -> [N,2] + return uv, zuv1[:, -1] + + +def unproject(uv, d, K): + """ + unproject pixels uv to 3D points + + Args: + uv: [N,2] + d: depth, [N,1] + K: [3,3] + + Returns: + 3D points, [N,3] + """ + duv = uv * d # (u,v) [N,2] + if type(K) == torch.Tensor: + duv1 = torch.cat([duv, d], dim=1) # z*(u,v,1) [N,3] + K_inv = torch.inverse(K) # [3,3] + points3d = torch.einsum('jk,nk->nj', K_inv, duv1) # [N,3] + elif type(K) == np.ndarray: + duv1 = np.concatenate((duv, d), axis=1) # z*(u,v,1) [N,3] + K_inv = np.linalg.inv(K) # [3,3] + points3d = np.einsum('jk,nk->nj', K_inv, duv1) # [N,3] + else: + raise TypeError("Input type should be 'torch.tensor' or 'numpy.ndarray'") + return points3d + + +def warp_se3(kpts0, params): + pose01 = params['pose01'] # relative motion + bbox0 = params['bbox0'] # row, col + bbox1 = params['bbox1'] + depth0 = params['depth0'] + depth1 = params['depth1'] + intrinsics0 = params['intrinsics0'] + intrinsics1 = params['intrinsics1'] + + # kpts0_valid: valid kpts0 + # z0_valid: depth of valid kpts0 + # ids0: the indices of valid kpts0 ( valid corners and valid depth) + # ids0_valid_corners: the valid indices of kpts0 in image ( 0<=x 0 ) + z0_valid, kpts0_valid, ids0, ids0_valid_corners, ids0_valid_depth = interpolate_depth(kpts0, depth0) + + # COLMAP convention + bkpts0_valid = kpts0_valid + bbox0[[1, 0]][None, :] + 0.5 + + # unproject pixel coordinate to 3D points (camera coordinate system) + bpoints3d0 = unproject(bkpts0_valid, z0_valid.unsqueeze(1), intrinsics0) # [:,3] + bpoints3d0_homo = to_homogeneous(bpoints3d0) # [:,4] + + # warp 3D point (camera 0 coordinate system) to 3D point (camera 1 coordinate system) + bpoints3d01_homo = torch.einsum('jk,nk->nj', pose01, bpoints3d0_homo) # [:,4] + bpoints3d01 = bpoints3d01_homo[:, 0:3] # [:,3] + + # project 3D point (camera coordinate system) to pixel coordinate + buv01, z01 = project(bpoints3d01, intrinsics1) # uv: [:,2], (h,w); z1: [N] + + uv01 = buv01 - bbox1[None, [1, 0]] - .5 + + # kpts01_valid: valid kpts01 + # z01_valid: depth of valid kpts01 + # ids01: the indices of valid kpts01 ( valid corners and valid depth) + # ids01_valid_corners: the valid indices of kpts01 in image ( 0<=x 0 ) + z01_interpolate, kpts01_valid, ids01, ids01_valid_corners, ids01_valid_depth = interpolate_depth(uv01, depth1) + + outimage_mask = torch.ones(ids0.shape[0], device=ids0.device).bool() + outimage_mask[ids01_valid_corners] = 0 + ids01_invalid_corners = torch.arange(0, ids0.shape[0], device=ids0.device)[outimage_mask] + ids_outside = ids0[ids01_invalid_corners] + + # ids_valid: matched kpts01 without occlusion + ids_valid = ids0[ids01] + kpts0_valid = kpts0_valid[ids01] + z01_proj = z01[ids01] + + inlier_mask = torch.abs(z01_proj - z01_interpolate) < 0.05 + + # indices of kpts01 with occlusion + ids_occlude = ids_valid[~inlier_mask] + + ids_valid = ids_valid[inlier_mask] + if ids_valid.size(0) == 0: + # raise ValueError('empty tensor: ids') + raise EmptyTensorError + + kpts01_valid = kpts01_valid[inlier_mask] + kpts0_valid = kpts0_valid[inlier_mask] + + # indices of kpts01 which are no matches in image1 for sure, + # other projected kpts01 are not sure because of no depth in image0 or imgae1 + ids_out = torch.cat([ids_outside, ids_occlude]) + + # kpts0_valid: valid keypoints0, the invalid and inconsistance keypoints are removed + # kpts01_valid: the warped valid keypoints0 + # ids: the valid indices + return kpts0_valid, kpts01_valid, ids_valid, ids_out + + +def warp(kpts0, params: dict): + mode = params['mode'] + if mode == 'homo': + return warp_homography(kpts0, params) + elif mode == 'se3': + return warp_se3(kpts0, params) + else: + raise ValueError('unknown mode!') + + +def warp_xy(kpts0_xy, params: dict): + w, h = params['width'], params['height'] + kpts0 = (kpts0_xy / 2 + 0.5) * kpts0_xy.new_tensor([[w - 1, h - 1]]) + kpts0, kpts01, ids = warp(kpts0, params) + kpts01_xy = 2 * kpts01 / kpts01.new_tensor([[w - 1, h - 1]]) - 1 + kpts0_xy = 2 * kpts0 / kpts0.new_tensor([[w - 1, h - 1]]) - 1 + return kpts0_xy, kpts01_xy, ids + + +def scale_intrinsics(K, scales): + scales = np.diag([1. / scales[0], 1. / scales[1], 1.]) + return np.dot(scales, K) + + +def warp_points3d(points3d0, pose01): + points3d0_homo = np.concatenate((points3d0, np.ones(points3d0.shape[0])[:, np.newaxis]), axis=1) # [:,4] + + points3d01_homo = np.einsum('jk,nk->nj', pose01, points3d0_homo) # [N,4] + points3d01 = points3d01_homo[:, 0:3] # [N,3] + + return points3d01 + + +def unproject_depth(depth, K): + h, w = depth.shape + + wh_range = np.mgrid[0:w, 0:h].transpose(2, 1, 0) # [H,W,2] + + uv = wh_range.reshape(-1, 2) + d = depth.reshape(-1, 1) + points3d = unproject(uv, d, K) + + valid = np.logical_and((d[:, 0] > 0), (points3d[:, 2] > 0)) + + return points3d, valid + + +@jit(nopython=True) +def project_depth_nn_python(uv, z, depth): + h, w = depth.shape + # TODO: speed up the for loop + for idx in range(len(uv)): + uvi = uv[idx] + x = int(round(uvi[0])) + y = int(round(uvi[1])) + + if x < 0 or y < 0 or x >= w or y >= h: + continue + + if depth[y, x] == 0. or depth[y, x] > z[idx]: + depth[y, x] = z[idx] + return depth + + +def project_nn(uv, z, depth): + """ + uv: pixel coordinates [N,2] + z: projected depth (xyz -> z) [N] + depth: output depth array: [h,w] + """ + if nn_cython: + return project_depth_nn_cython(uv.astype(np.float64), + z.astype(np.float64), + depth.astype(np.float64)) + else: + return project_depth_nn_python(uv, z, depth) + + +def warp_depth(depth0, intrinsics0, intrinsics1, pose01, shape1): + points3d0, valid0 = unproject_depth(depth0, intrinsics0) # [:,3] + points3d0 = points3d0[valid0] + + points3d01 = warp_points3d(points3d0, pose01) + + uv01, z01 = project(points3d01, intrinsics1) # uv: [N,2], (h,w); z1: [N] + + depth01 = project_nn(uv01, z01, depth=np.zeros(shape=shape1)) + + return depth01 + + +def warp_points2d(uv0, d0, intrinsics0, intrinsics1, pose01): + points3d0 = unproject(uv0, d0, intrinsics0) + points3d01 = warp_points3d(points3d0, pose01) + uv01, z01 = project(points3d01, intrinsics1) + return uv01, z01 + + +def display_image_in_actual_size(image): + import matplotlib.pyplot as plt + + dpi = 100 + height, width = image.shape[:2] + + # What size does the figure need to be in inches to fit the image? + figsize = width / float(dpi), height / float(dpi) + + # Create a figure of the right size with one axes that takes up the full figure + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0, 0, 1, 1]) + + # Hide spines, ticks, etc. + ax.axis('off') + + # Display the image. + if len(image.shape) == 3: + ax.imshow(image, cmap='gray') + elif len(image.shape) == 2: + if image.dtype == np.uint8: + ax.imshow(image, cmap='gray') + else: + ax.imshow(image) + ax.text(20, 20, f"Range: {image.min():g}~{image.max():g}", color='red') + plt.show() + + +# ====================================== copied from ASLFeat +from datetime import datetime + + +class ClassProperty(property): + """For dynamically obtaining system time""" + + def __get__(self, cls, owner): + return classmethod(self.fget).__get__(None, owner)() + + +class Notify(object): + """Colorful printing prefix. + A quick example: + print(Notify.INFO, YOUR TEXT, Notify.ENDC) + """ + + def __init__(self): + pass + + @ClassProperty + def HEADER(cls): + return str(datetime.now()) + ': \033[95m' + + @ClassProperty + def INFO(cls): + return str(datetime.now()) + ': \033[92mI' + + @ClassProperty + def OKBLUE(cls): + return str(datetime.now()) + ': \033[94m' + + @ClassProperty + def WARNING(cls): + return str(datetime.now()) + ': \033[93mW' + + @ClassProperty + def FAIL(cls): + return str(datetime.now()) + ': \033[91mF' + + @ClassProperty + def BOLD(cls): + return str(datetime.now()) + ': \033[1mB' + + @ClassProperty + def UNDERLINE(cls): + return str(datetime.now()) + ': \033[4mU' + + ENDC = '\033[0m' + +def get_essential(T0, T1): + R0 = T0[:3, :3] + R1 = T1[:3, :3] + + t0 = T0[:3, 3].reshape(3, 1) + t1 = T1[:3, 3].reshape(3, 1) + + R0 = torch.tensor(R0, dtype=torch.float32) + R1 = torch.tensor(R1, dtype=torch.float32) + t0 = torch.tensor(t0, dtype=torch.float32) + t1 = torch.tensor(t1, dtype=torch.float32) + + E = essential_from_Rt(R0, t0, R1, t1) + + return E + +def get_fundamental(E, K0, K1): + F = fundamental_from_essential(E, K0, K1) + + return F +try: + # for internel use only + from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT +except Exception: + MEGADEPTH_CLIENT = SCANNET_CLIENT = None + +# --- DATA IO --- + +def load_array_from_s3( + path, client, cv_type, + use_h5py=False, +): + byte_str = client.Get(path) + try: + if not use_h5py: + raw_array = np.fromstring(byte_str, np.uint8) + data = cv2.imdecode(raw_array, cv_type) + else: + f = io.BytesIO(byte_str) + data = np.array(h5py.File(f, 'r')['/depth']) + except Exception as ex: + print(f"==> Data loading failure: {path}") + raise ex + + assert data is not None + return data + + +def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ + else cv2.IMREAD_COLOR + if str(path).startswith('s3://'): + image = load_array_from_s3(str(path), client, cv_type) + else: + image = cv2.imread(str(path), 1) + + if augment_fn is not None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (h, w) + + +def get_resized_wh(w, h, resize=None): + if resize is not None: # resize the longer edge + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new, h_new = map(lambda x: int(x // df * df), [w, h]) + else: + w_new, h_new = w, h + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + elif inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + else: + raise NotImplementedError() + return padded, mask + + +# --- MEGADEPTH --- + +def fix_path_from_d2net(path): + if not path: + return None + + path = path.replace('Undistorted_SfM/', '') + path = path.replace('images', 'dense0/imgs') + path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '') + + return path + +def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) + + # resize image + w, h = image.shape[1], image.shape[0] + + if len(resize) == 2: + w_new, h_new = resize + else: + resize = resize[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + #image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float().permute(2,0,1) / 255 # (h, w) -> (1, h, w) and normalized + mask = torch.from_numpy(mask) if mask is not None else None + + return image, mask, scale + +def imread_color(path, augment_fn=None, client=SCANNET_CLIENT): + cv_type = cv2.IMREAD_COLOR + # if str(path).startswith('s3://'): + # image = load_array_from_s3(str(path), client, cv_type) + # else: + # image = cv2.imread(str(path), cv_type) + + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if augment_fn is not None: + image = augment_fn(image) + # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (3, h, w) + + +def read_megadepth_color(path, + resize=None, + df=None, + padding=False, + augment_fn=None, + rotation=0): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (3, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_color(path, augment_fn, client=MEGADEPTH_CLIENT) + + if rotation != 0: + image = np.rot90(image, k=rotation).copy() + + # resize image + if resize is not None: + w, h = image.shape[1], image.shape[0] + if len(resize) == 2: + w_new, h_new = resize + else: + resize = resize[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + scale_wh = torch.tensor([w_new, h_new], dtype=torch.float) + else: + scale = torch.tensor([1., 1.], dtype=torch.float) + scale_wh = torch.tensor([image.shape[1], image.shape[0]], dtype=torch.float) + + image = image.transpose(2, 0, 1) + + if padding: # padding + if resize is not None: + pad_to = max(h_new, w_new) + else: + pad_to = 2000 + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized + mask = torch.from_numpy(mask) if mask is not None else None + + return image, mask, scale + +def read_megadepth_depth(path, pad_to=None): + + if str(path).startswith('s3://'): + depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) + else: + depth = np.array(h5py.File(path, 'r')['depth']) + if pad_to is not None: + depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) + depth = torch.from_numpy(depth).float() # (h, w) + return depth + +def get_image_name(path): + return path.split('/')[-1] + +def scale_intrinsics(K, scales): + scales = np.diag([1. / scales[0], 1. / scales[1], 1.]) + return np.dot(scales, K) \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/matchers/__init__.py b/imcui/third_party/rdd/RDD/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..824848ef4ddf5ff92f1050799cb2be7bb4693622 --- /dev/null +++ b/imcui/third_party/rdd/RDD/matchers/__init__.py @@ -0,0 +1,3 @@ +from .dual_softmax_matcher import DualSoftmaxMatcher +from .dense_matcher import DenseMatcher +from .lightglue import LightGlue \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/matchers/dense_matcher.py b/imcui/third_party/rdd/RDD/matchers/dense_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..c862be1699c27ac92aa4f8fb0c3771815edf491c --- /dev/null +++ b/imcui/third_party/rdd/RDD/matchers/dense_matcher.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import poselib + +class DenseMatcher(nn.Module): + def __init__(self, inv_temperature = 20, thr = 0.01): + super().__init__() + self.inv_temperature = inv_temperature + self.thr = thr + + def forward(self, info0, info1, thr = None, err_thr=4, min_num_inliers=30): + + desc0 = info0['descriptors'] + desc1 = info1['descriptors'] + + inds, P = self.dual_softmax(desc0, desc1, thr=thr) + + mkpts_0 = info0['keypoints'][inds[:,0]] + mkpts_1 = info1['keypoints'][inds[:,1]] + mconf = P[inds[:,0], inds[:,1]] + Fm, inliers = self.get_fundamental_matrix(mkpts_0, mkpts_1) + + if inliers.sum() >= min_num_inliers: + desc1_dense = info0['descriptors_dense'] + desc2_dense = info1['descriptors_dense'] + + inds_dense, P_dense = self.dual_softmax(desc1_dense, desc2_dense, thr=thr) + + mkpts_0_dense = info0['keypoints_dense'][inds_dense[:,0]] + mkpts_1_dense = info1['keypoints_dense'][inds_dense[:,1]] + mconf_dense = P_dense[inds_dense[:,0], inds_dense[:,1]] + + mkpts_0_dense, mkpts_1_dense, mconf_dense = self.refine_matches(mkpts_0_dense, mkpts_1_dense, mconf_dense, Fm, err_thr=err_thr) + mkpts_0 = mkpts_0[inliers] + mkpts_1 = mkpts_1[inliers] + mconf = mconf[inliers] + # concatenate the matches + mkpts_0 = torch.cat([mkpts_0, mkpts_0_dense], dim=0) + mkpts_1 = torch.cat([mkpts_1, mkpts_1_dense], dim=0) + mconf = torch.cat([mconf, mconf_dense], dim=0) + + return mkpts_0, mkpts_1, mconf + + def get_fundamental_matrix(self, kpts_0, kpts_1): + Fm, info = poselib.estimate_fundamental(kpts_0.cpu().numpy(), kpts_1.cpu().numpy(), {'max_epipolar_error': 1, 'progressive_sampling': True}, {}) + inliers = info['inliers'] + Fm = torch.tensor(Fm, device=kpts_0.device, dtype=kpts_0.dtype) + inliers = torch.tensor(inliers, device=kpts_0.device, dtype=torch.bool) + return Fm, inliers + + def dual_softmax(self, desc0, desc1, thr = None): + if thr is None: + thr = self.thr + dist_mat = (desc0 @ desc1.t()) * self.inv_temperature + P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1) + inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values) + * (P == P.max(dim=-2, keepdim = True).values) * (P >= thr)) + + return inds, P + + @torch.inference_mode() + def refine_matches(self, mkpts_0, mkpts_1, mconf, Fm, err_thr=4): + mkpts_0_h = torch.cat([mkpts_0, torch.ones(mkpts_0.shape[0], 1, device=mkpts_0.device)], dim=1) # (N, 3) + mkpts_1_h = torch.cat([mkpts_1, torch.ones(mkpts_1.shape[0], 1, device=mkpts_1.device)], dim=1) # (N, 3) + + lines_1 = torch.matmul(Fm, mkpts_0_h.T).T + + a, b, c = lines_1[:, 0], lines_1[:, 1], lines_1[:, 2] + + x1, y1 = mkpts_1[:, 0], mkpts_1[:, 1] + + denom = a**2 + b**2 + 1e-8 + + x_offset = (b * (b * x1 - a * y1) - a * c) / denom - x1 + y_offset = (a * (a * y1 - b * x1) - b * c) / denom - y1 + + inds = (x_offset.abs() < err_thr) | (y_offset.abs() < err_thr) + + x_offset = x_offset[inds] + y_offset = y_offset[inds] + + mkpts_0 = mkpts_0[inds] + mkpts_1 = mkpts_1[inds] + + refined_mkpts_1 = mkpts_1 + torch.stack([x_offset, y_offset], dim=1) + + return mkpts_0, refined_mkpts_1, mconf[inds] diff --git a/imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py b/imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a5d6e73fa4950ea9556e07da7e570ea61c8ffa --- /dev/null +++ b/imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class DualSoftmaxMatcher(nn.Module): + def __init__(self, inv_temperature = 20, thr = 0.01): + super().__init__() + self.inv_temperature = inv_temperature + self.thr = thr + + def forward(self, info0, info1, thr = None): + desc0 = info0['descriptors'] + desc1 = info1['descriptors'] + + inds, P = self.dual_softmax(desc0, desc1, thr) + mkpts0 = info0['keypoints'][inds[:,0]] + mkpts1 = info1['keypoints'][inds[:,1]] + mconf = P[inds[:,0], inds[:,1]] + + return mkpts0, mkpts1, mconf + + def dual_softmax(self, desc0, desc1, thr = None): + if thr is None: + thr = self.thr + dist_mat = (desc0 @ desc1.t()) * self.inv_temperature + P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1) + + inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values) + * (P == P.max(dim=-2, keepdim = True).values) * (P >= thr)) + + return inds, P \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/matchers/lightglue.py b/imcui/third_party/rdd/RDD/matchers/lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..9405083813fd499827c52ecf6d0520c40d1e5de8 --- /dev/null +++ b/imcui/third_party/rdd/RDD/matchers/lightglue.py @@ -0,0 +1,667 @@ +""" +Modified from +https://github.com/cvg/LightGlue +""" + +import warnings +from pathlib import Path +from types import SimpleNamespace +from typing import Callable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +try: + from flash_attn.modules.mha import FlashCrossAttention +except ModuleNotFoundError: + FlashCrossAttention = None + +if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"): + FLASH_AVAILABLE = True +else: + FLASH_AVAILABLE = False + +torch.backends.cudnn.deterministic = True + + +@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) +def normalize_keypoints( + kpts: torch.Tensor, size: Optional[torch.Tensor] = None +) -> torch.Tensor: + if size is None: + size = 1 + kpts.max(-2).values - kpts.min(-2).values + elif not isinstance(size, torch.Tensor): + size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype) + size = size.to(kpts) + shift = size / 2 + scale = size.max(-1).values / 2 + kpts = (kpts - shift[..., None, :]) / scale[..., None, None] + return kpts + + +def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]: + if length <= x.shape[-2]: + return x, torch.ones_like(x[..., :1], dtype=torch.bool) + pad = torch.ones( + *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype + ) + y = torch.cat([x, pad], dim=-2) + mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device) + mask[..., : x.shape[-2], :] = True + return y, mask + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """encode position vector""" + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class TokenConfidence(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """get confidence tokens""" + return ( + self.token(desc0.detach()).squeeze(-1), + self.token(desc1.detach()).squeeze(-1), + ) + + +class Attention(nn.Module): + def __init__(self, allow_flash: bool) -> None: + super().__init__() + if allow_flash and not FLASH_AVAILABLE: + warnings.warn( + "FlashAttention is not available. For optimal speed, " + "consider installing torch >= 2.0 or flash-attn.", + stacklevel=2, + ) + self.enable_flash = allow_flash and FLASH_AVAILABLE + self.has_sdp = hasattr(F, "scaled_dot_product_attention") + if allow_flash and FlashCrossAttention: + self.flash_ = FlashCrossAttention() + if self.has_sdp: + torch.backends.cuda.enable_flash_sdp(allow_flash) + + def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if q.shape[-2] == 0 or k.shape[-2] == 0: + return q.new_zeros((*q.shape[:-1], v.shape[-1])) + if self.enable_flash and q.device.type == "cuda": + # use torch 2.0 scaled_dot_product_attention with flash + if self.has_sdp: + args = [x.half().contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype) + return v if mask is None else v.nan_to_num() + else: + assert mask is None + q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]] + m = self.flash_(q.half(), torch.stack([k, v], 2).half()) + return m.transpose(-2, -3).to(q.dtype).clone() + elif self.has_sdp: + args = [x.contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask) + return v if mask is None else v.nan_to_num() + else: + s = q.shape[-1] ** -0.5 + sim = torch.einsum("...id,...jd->...ij", q, k) * s + if mask is not None: + sim.masked_fill(~mask, -float("inf")) + attn = F.softmax(sim, -1) + return torch.einsum("...ij,...jd->...id", attn, v) + + +class SelfBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0 + self.head_dim = self.embed_dim // num_heads + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + self.inner_attn = Attention(flash) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + + def forward( + self, + x: torch.Tensor, + encoding: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.Wqkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + context = self.inner_attn(q, k, v, mask=mask) + message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2)) + return x + self.ffn(torch.cat([x, message], -1)) + + +class CrossBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.heads = num_heads + dim_head = embed_dim // num_heads + self.scale = dim_head**-0.5 + inner_dim = dim_head * num_heads + self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + if flash and FLASH_AVAILABLE: + self.flash = Attention(True) + else: + self.flash = None + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward( + self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> List[torch.Tensor]: + qk0, qk1 = self.map_(self.to_qk, x0, x1) + v0, v1 = self.map_(self.to_v, x0, x1) + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1), + ) + if self.flash is not None and qk0.device.type == "cuda": + m0 = self.flash(qk0, qk1, v1, mask) + m1 = self.flash( + qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None + ) + else: + qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 + sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1) + if mask is not None: + sim = sim.masked_fill(~mask, -float("inf")) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) + m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) + if mask is not None: + m0, m1 = m0.nan_to_num(), m1.nan_to_num() + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) + m0, m1 = self.map_(self.to_out, m0, m1) + x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) + x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) + return x0, x1 + + +class TransformerLayer(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.self_attn = SelfBlock(*args, **kwargs) + self.cross_attn = CrossBlock(*args, **kwargs) + + def forward( + self, + desc0, + desc1, + encoding0, + encoding1, + mask0: Optional[torch.Tensor] = None, + mask1: Optional[torch.Tensor] = None, + ): + if mask0 is not None and mask1 is not None: + return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1) + else: + desc0 = self.self_attn(desc0, encoding0) + desc1 = self.self_attn(desc1, encoding1) + return self.cross_attn(desc0, desc1) + + # This part is compiled and allows padding inputs + def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1): + mask = mask0 & mask1.transpose(-1, -2) + mask0 = mask0 & mask0.transpose(-1, -2) + mask1 = mask1 & mask1.transpose(-1, -2) + desc0 = self.self_attn(desc0, encoding0, mask0) + desc1 = self.self_attn(desc1, encoding1, mask1) + return self.cross_attn(desc0, desc1, mask) + + +def sigmoid_log_double_softmax( + sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + b, m, n = sim.shape + certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2) + scores0 = F.log_softmax(sim, 2) + scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = sim.new_full((b, m + 1, n + 1), 0) + scores[:, :m, :n] = scores0 + scores1 + certainties + scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1)) + scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1)) + return scores + + +class MatchAssignment(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + self.matchability = nn.Linear(dim, 1, bias=True) + self.final_proj = nn.Linear(dim, dim, bias=True) + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """build assignment matrix from descriptors""" + mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) + _, _, d = mdesc0.shape + mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25 + sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1) + z0 = self.matchability(desc0) + z1 = self.matchability(desc1) + scores = sigmoid_log_double_softmax(sim, z0, z1) + return scores, sim + + def get_matchability(self, desc: torch.Tensor): + return torch.sigmoid(self.matchability(desc)).squeeze(-1) + + +def filter_matches(scores: torch.Tensor, th: float): + """obtain matches from a log assignment matrix [Bx M+1 x N+1]""" + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + m0, m1 = max0.indices, max1.indices + indices0 = torch.arange(m0.shape[1], device=m0.device)[None] + indices1 = torch.arange(m1.shape[1], device=m1.device)[None] + mutual0 = indices0 == m1.gather(1, m0) + mutual1 = indices1 == m0.gather(1, m1) + max0_exp = max0.values.exp() + zero = max0_exp.new_tensor(0) + mscores0 = torch.where(mutual0, max0_exp, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) + valid0 = mutual0 & (mscores0 > th) + valid1 = mutual1 & valid0.gather(1, m1) + m0 = torch.where(valid0, m0, -1) + m1 = torch.where(valid1, m1, -1) + return m0, m1, mscores0, mscores1 + + +class LightGlue(nn.Module): + default_conf = { + "name": "lightglue", # just for interfacing + "input_dim": 256, # input descriptor dimension (autoselected from weights) + "descriptor_dim": 256, + "add_scale_ori": False, + "n_layers": 9, + "num_heads": 4, + "flash": True, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "depth_confidence": -1, # early stopping, disable with -1 + "width_confidence": -1, # point pruning, disable with -1 + "filter_threshold": 0.01, # match threshold + "weights": None, + } + + # Point pruning involves an overhead (gather). + # Therefore, we only activate it if there are enough keypoints. + pruning_keypoint_thresholds = { + "cpu": -1, + "mps": -1, + "cuda": 1024, + "flash": 1536, + } + + required_data_keys = ["image0", "image1"] + + version = "v0.1_arxiv" + url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth" + + features = { + "superpoint": { + "weights": "superpoint_lightglue", + "input_dim": 256, + }, + "disk": { + "weights": "disk_lightglue", + "input_dim": 128, + }, + "aliked": { + "weights": "aliked_lightglue", + "input_dim": 128, + }, + "sift": { + "weights": "sift_lightglue", + "input_dim": 128, + "add_scale_ori": True, + }, + "doghardnet": { + "weights": "doghardnet_lightglue", + "input_dim": 128, + "add_scale_ori": True, + }, + "rdd": { + "weights": './weights/RDD_lg-v2.pth', + "input_dim": 256, + }, + } + + def __init__(self, features="rdd", **conf) -> None: + super().__init__() + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + if features is not None: + if features not in self.features: + raise ValueError( + f"Unsupported features: {features} not in " + f"{{{','.join(self.features)}}}" + ) + for k, v in self.features[features].items(): + setattr(conf, k, v) + + if conf.input_dim != conf.descriptor_dim: + self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) + else: + self.input_proj = nn.Identity() + + head_dim = conf.descriptor_dim // conf.num_heads + self.posenc = LearnableFourierPositionalEncoding( + 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim + ) + + h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim + + self.transformers = nn.ModuleList( + [TransformerLayer(d, h, conf.flash) for _ in range(n)] + ) + + self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) + self.token_confidence = nn.ModuleList( + [TokenConfidence(d) for _ in range(n - 1)] + ) + self.register_buffer( + "confidence_thresholds", + torch.Tensor( + [self.confidence_threshold(i) for i in range(self.conf.n_layers)] + ), + ) + + state_dict = None + if features is not None and features != 'rdd': + fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth" + state_dict = torch.hub.load_state_dict_from_url( + self.url.format(self.version, features), file_name=fname + ) + self.load_state_dict(state_dict, strict=False) + elif conf.weights is not None: + if features == 'rdd': + path = Path(conf.weights) + else: + path = Path(__file__).parent + path = path / "weights/{}.pth".format(self.conf.weights) + state_dict = torch.load(str(path), map_location="cpu") + + if state_dict: + # rename old state dict entries + for i in range(self.conf.n_layers): + pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" + state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" + state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + self.load_state_dict(state_dict, strict=False) + + # static lengths LightGlue is compiled for (only used with torch.compile) + self.static_lengths = None + + def compile( + self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536] + ): + if self.conf.width_confidence != -1: + warnings.warn( + "Point pruning is partially disabled for compiled forward.", + stacklevel=2, + ) + + torch._inductor.cudagraph_mark_step_begin() + for i in range(self.conf.n_layers): + self.transformers[i].masked_forward = torch.compile( + self.transformers[i].masked_forward, mode=mode, fullgraph=True + ) + + self.static_lengths = static_lengths + + def forward(self, data: dict) -> dict: + """ + Match keypoints and descriptors between two images + + Input (dict): + image0: dict + keypoints: [B x M x 2] + descriptors: [B x M x D] + image: [B x C x H x W] or image_size: [B x 2] + image1: dict + keypoints: [B x N x 2] + descriptors: [B x N x D] + image: [B x C x H x W] or image_size: [B x 2] + Output (dict): + matches0: [B x M] + matching_scores0: [B x M] + matches1: [B x N] + matching_scores1: [B x N] + matches: List[[Si x 2]] + scores: List[[Si]] + stop: int + prune0: [B x M] + prune1: [B x N] + """ + with torch.autocast(enabled=self.conf.mp, device_type="cuda"): + return self._forward(data) + + def _forward(self, data: dict) -> dict: + for key in self.required_data_keys: + assert key in data, f"Missing key {key} in data" + data0, data1 = data["image0"], data["image1"] + kpts0, kpts1 = data0["keypoints"], data1["keypoints"] + b, m, _ = kpts0.shape + b, n, _ = kpts1.shape + device = kpts0.device + size0, size1 = data0.get("image_size"), data1.get("image_size") + kpts0 = normalize_keypoints(kpts0, size0).clone() + kpts1 = normalize_keypoints(kpts1, size1).clone() + + if self.conf.add_scale_ori: + kpts0 = torch.cat( + [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1 + ) + kpts1 = torch.cat( + [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1 + ) + desc0 = data0["descriptors"].detach().contiguous() + desc1 = data1["descriptors"].detach().contiguous() + + assert desc0.shape[-1] == self.conf.input_dim + assert desc1.shape[-1] == self.conf.input_dim + + if torch.is_autocast_enabled(): + desc0 = desc0.half() + desc1 = desc1.half() + + mask0, mask1 = None, None + c = max(m, n) + do_compile = self.static_lengths and c <= max(self.static_lengths) + if do_compile: + kn = min([k for k in self.static_lengths if k >= c]) + desc0, mask0 = pad_to_length(desc0, kn) + desc1, mask1 = pad_to_length(desc1, kn) + kpts0, _ = pad_to_length(kpts0, kn) + kpts1, _ = pad_to_length(kpts1, kn) + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + # cache positional embeddings + encoding0 = self.posenc(kpts0) + encoding1 = self.posenc(kpts1) + + # GNN + final_proj + assignment + do_early_stop = self.conf.depth_confidence > 0 + do_point_pruning = self.conf.width_confidence > 0 and not do_compile + pruning_th = self.pruning_min_kpts(device) + if do_point_pruning: + ind0 = torch.arange(0, m, device=device)[None] + ind1 = torch.arange(0, n, device=device)[None] + # We store the index of the layer at which pruning is detected. + prune0 = torch.ones_like(ind0) + prune1 = torch.ones_like(ind1) + token0, token1 = None, None + for i in range(self.conf.n_layers): + if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints + break + desc0, desc1 = self.transformers[i]( + desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1 + ) + if i == self.conf.n_layers - 1: + continue # no early stopping or adaptive width at last layer + + if do_early_stop: + token0, token1 = self.token_confidence[i](desc0, desc1) + if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n): + break + if do_point_pruning and desc0.shape[-2] > pruning_th: + scores0 = self.log_assignment[i].get_matchability(desc0) + prunemask0 = self.get_pruning_mask(token0, scores0, i) + keep0 = torch.where(prunemask0)[1] + ind0 = ind0.index_select(1, keep0) + desc0 = desc0.index_select(1, keep0) + encoding0 = encoding0.index_select(-2, keep0) + prune0[:, ind0] += 1 + if do_point_pruning and desc1.shape[-2] > pruning_th: + scores1 = self.log_assignment[i].get_matchability(desc1) + prunemask1 = self.get_pruning_mask(token1, scores1, i) + keep1 = torch.where(prunemask1)[1] + ind1 = ind1.index_select(1, keep1) + desc1 = desc1.index_select(1, keep1) + encoding1 = encoding1.index_select(-2, keep1) + prune1[:, ind1] += 1 + + if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints + m0 = desc0.new_full((b, m), -1, dtype=torch.long) + m1 = desc1.new_full((b, n), -1, dtype=torch.long) + mscores0 = desc0.new_zeros((b, m)) + mscores1 = desc1.new_zeros((b, n)) + matches = desc0.new_empty((b, 0, 2), dtype=torch.long) + mscores = desc0.new_empty((b, 0)) + if not do_point_pruning: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + return { + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding + scores, _ = self.log_assignment[i](desc0, desc1) + m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) + matches, mscores = [], [] + for k in range(b): + valid = m0[k] > -1 + m_indices_0 = torch.where(valid)[0] + m_indices_1 = m0[k][valid] + if do_point_pruning: + m_indices_0 = ind0[k, m_indices_0] + m_indices_1 = ind1[k, m_indices_1] + matches.append(torch.stack([m_indices_0, m_indices_1], -1)) + mscores.append(mscores0[k][valid]) + + # TODO: Remove when hloc switches to the compact format. + if do_point_pruning: + m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype) + m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype) + m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) + m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) + mscores0_ = torch.zeros((b, m), device=mscores0.device) + mscores1_ = torch.zeros((b, n), device=mscores1.device) + mscores0_[:, ind0] = mscores0 + mscores1_[:, ind1] = mscores1 + m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_ + else: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + + return { + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + def confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers) + return np.clip(threshold, 0, 1) + + def get_pruning_mask( + self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int + ) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.conf.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. + keep |= confidences <= self.confidence_thresholds[layer_index] + return keep + + def check_if_stop( + self, + confidences0: torch.Tensor, + confidences1: torch.Tensor, + layer_index: int, + num_points: int, + ) -> torch.Tensor: + """evaluate stopping condition""" + confidences = torch.cat([confidences0, confidences1], -1) + threshold = self.confidence_thresholds[layer_index] + ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points + return ratio_confident > self.conf.depth_confidence + + def pruning_min_kpts(self, device: torch.device): + if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda": + return self.pruning_keypoint_thresholds["flash"] + else: + return self.pruning_keypoint_thresholds[device.type] diff --git a/imcui/third_party/rdd/RDD/models/backbone.py b/imcui/third_party/rdd/RDD/models/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..8008e32929f0ac2c073f2e1b0847df085438195c --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/backbone.py @@ -0,0 +1,147 @@ +# Modified from Deformable DETR +# https://github.com/fundamentalvision/Deformable-DETR +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List +import torch.distributed as dist +from .position_encoding import build_position_encoding + +from ..utils.misc import NestedTensor, is_main_process + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = self.eps + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool, n_layers = 4): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + if n_layers == 4: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + self.strides = [4, 8, 16, 32] + self.num_channels = [256, 512, 1024, 2048] + elif n_layers == 3: + return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [8, 16, 32] + self.num_channels = [512, 1024, 2048] + else: + raise ValueError("n_layers should be 3 or 4") + + else: + return_layers = {'layer4': "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool, + n_layers = 4): + norm_layer = FrozenBatchNorm2d + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + weights='ResNet50_Weights.IMAGENET1K_V1', norm_layer=norm_layer) + assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" + super().__init__(backbone, train_backbone, return_interm_layers, n_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in sorted(xs.items()): + out.append(x) + + # position encoding + for x in out: + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(config): + position_embedding = build_position_encoding(config) + train_backbone = config['lr_backbone'] > 0 + return_interm_layers = True + n_layers = config['num_feature_levels'] - 1 + backbone = Backbone('resnet50', train_backbone, return_interm_layers, False, n_layers=n_layers) + model = Joiner(backbone, position_embedding) + return model \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/deformable_transformer.py b/imcui/third_party/rdd/RDD/models/deformable_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2c4e238a7b631039a081ff7116d5d2d089723d --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/deformable_transformer.py @@ -0,0 +1,270 @@ +# Modified from Deformable DETR +# https://github.com/fundamentalvision/Deformable-DETR +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import copy +from typing import Optional, List +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ + +from ..utils.misc import inverse_sigmoid +from .ops.modules import MSDeformAttn + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__(self, + d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None, debug=False): + # self attention + if debug: + src2, sampled_points = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) + else: + src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + if debug: + return src, sampled_points + return src + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, debug=False): + output = src + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) + for _, layer in enumerate(self.layers): + if debug: + output, sampled_points = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask, debug=debug) + else: + output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) + if debug: + return output, reference_points, sampled_points + return output + +class DecoderLayer(nn.Module): + def __init__(self, d_model=256, n_head=8, dropout=0.1): + super().__init__() + self.nhead = n_head + self.dim = d_model // n_head + self.attention = LinearAttention() + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, tgt, src, tgt_mask=None, src_mask=None): + + bs = tgt.size(0) + query, key, value = tgt, src, src + + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + + tgt2 = self.attention(query, key, value, q_mask=tgt_mask, kv_mask=src_mask) + tgt2 = tgt2.view(bs, -1, self.nhead*self.dim) + tgt2 = self.norm1(self.dropout1(tgt2)) + tgt2 = self.mlp(torch.cat([tgt, tgt2], dim=2)) + + tgt2 = self.norm2(tgt2) + + return tgt + tgt2 + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class Decoder(nn.Module): + def __init__(self, decoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): + for layer in self.layers: + tgt = layer(tgt, memory, tgt_mask=tgt_mask, src_mask=memory_mask) + + return tgt + +import math + +class DeformableTransformer(nn.Module): + def __init__(self, d_model=256, nhead=8, + num_encoder_layers=4, dim_feedforward=1024, dropout=0.1, + activation="relu", + num_feature_levels=5, enc_n_points=8): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + + # Encoder and Decoder + encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, enc_n_points) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + # Embedding for feature levels (multi-scale) + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + normal_(self.level_embed) + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds): + + # Prepare inputs for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2).to(src.device) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, debug=False) + + return memory, spatial_shapes, level_start_index + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +def build_deforamble_transformer(config): + return DeformableTransformer(d_model=config['d_model'], nhead=config['nhead'], + num_encoder_layers=config['num_encoder_layers'], dim_feedforward=config['dim_feedforward'], dropout=config['dropout'], + activation=config['activation'], + num_feature_levels=config['num_feature_levels'], enc_n_points=config['enc_n_points']) \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/descriptor.py b/imcui/third_party/rdd/RDD/models/descriptor.py new file mode 100644 index 0000000000000000000000000000000000000000..b86215d9ca2d2521f5e099cc4419a69b93e9fe86 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/descriptor.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils.misc import NestedTensor, nested_tensor_from_tensor_list +import torchvision.transforms as transforms +from .backbone import build_backbone +from .deformable_transformer import build_deforamble_transformer + +class BasicLayer(nn.Module): + """ + Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False): + super().__init__() + self.layer = nn.Sequential( + nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias), + nn.BatchNorm2d(out_channels, affine=False), + nn.ReLU(inplace = False), + ) + + def forward(self, x): + return self.layer(x) + +class RDD_Descriptor(nn.Module): + def __init__(self, backbone, transformer, num_feature_levels): + super().__init__() + self.transformer = transformer + self.hidden_dim = transformer.d_model + self.num_feature_levels = num_feature_levels + + self.matchibility_head = nn.Sequential( + BasicLayer(256, 128, 1, padding=0), + BasicLayer(128, 64, 1, padding=0), + nn.Conv2d (64, 1, 1), + nn.Sigmoid() + ) + + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1), + nn.GroupNorm(32, self.hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, self.hidden_dim), + )) + in_channels = self.hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], self.hidden_dim, kernel_size=1), + nn.GroupNorm(32, self.hidden_dim), + )]) + self.backbone = backbone + self.stride = backbone.strides[0] + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + def forward(self, samples: NestedTensor): + + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + flatten_feats, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos) + # Reshape the flattened features back to the original spatial shapes + feats = [] + level_start_index = torch.cat((level_start_index, torch.tensor([flatten_feats.shape[1]+1]).to(level_start_index.device))) + for i, shape in enumerate(spatial_shapes): + assert len(shape) == 2 + temp = flatten_feats[:, level_start_index[i] : level_start_index[i+1], :] + feats.append(temp.transpose(1, 2).view(-1, self.hidden_dim, *shape)) + + # Sum up the features from different levels + final_feature = feats[0] + for feat in feats[1:]: + final_feature = final_feature + F.interpolate(feat, size=final_feature.shape[-2:], mode='bilinear', align_corners=True) + + matchibility = self.matchibility_head(final_feature) + + return final_feature, matchibility + + +def build_descriptor(config): + backbone = build_backbone(config) + transformer = build_deforamble_transformer(config) + return RDD_Descriptor(backbone, transformer, config['num_feature_levels']) \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/detector.py b/imcui/third_party/rdd/RDD/models/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..7c93ce3a011ee334113639ee9f52419405fd57f2 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/detector.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import resnet +from typing import Optional, Callable +from ..utils.misc import NestedTensor + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None): + super().__init__() + if gate is None: + self.gate = nn.ReLU(inplace=False) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = resnet.conv3x3(in_channels, out_channels) + self.bn1 = norm_layer(out_channels) + self.conv2 = resnet.conv3x3(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + + def forward(self, x): + x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W + x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W + return x + +class ResBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResBlock, self).__init__() + if gate is None: + self.gate = nn.ReLU(inplace=False) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('ResBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in ResBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = resnet.conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.conv2 = resnet.conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.gate(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = out + identity + out = self.gate(out) + + return out + +class RDD_detector(nn.Module): + def __init__(self, block_dims, hidden_dim=128): + super().__init__() + self.gate = nn.ReLU(inplace=False) + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4) + self.block1 = ConvBlock(3, block_dims[0], self.gate, nn.BatchNorm2d) + self.block2 = ResBlock(inplanes=block_dims[0], planes=block_dims[1], stride=1, + downsample=nn.Conv2d(block_dims[0], block_dims[1], 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d) + self.block3 = ResBlock(inplanes=block_dims[1], planes=block_dims[2], stride=1, + downsample=nn.Conv2d(block_dims[1], block_dims[2], 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d) + self.block4 = ResBlock(inplanes=block_dims[2], planes=block_dims[3], stride=1, + downsample=nn.Conv2d(block_dims[2], block_dims[3], 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d) + + self.conv1 = resnet.conv1x1(block_dims[0], hidden_dim // 4) + self.conv2 = resnet.conv1x1(block_dims[1], hidden_dim // 4) + self.conv3 = resnet.conv1x1(block_dims[2], hidden_dim // 4) + self.conv4 = resnet.conv1x1(block_dims[3], hidden_dim // 4) + + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) + self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) + + self.convhead2 = nn.Sequential( + resnet.conv1x1(hidden_dim, 1), + nn.Sigmoid() + ) + + def forward(self, samples: NestedTensor): + x1 = self.block1(samples.tensors) + x2 = self.pool2(x1) + x2 = self.block2(x2) # B x c2 x H/2 x W/2 + x3 = self.pool4(x2) + x3 = self.block3(x3) # B x c3 x H/8 x W/8 + x4 = self.pool4(x3) + x4 = self.block4(x4) + + x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W + x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2 + x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8 + x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32 + + x2_up = self.upsample2(x2) + x3_up = self.upsample8(x3) + x4_up = self.upsample32(x4) + + x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1) + scoremap = self.convhead2(x1234) + + return scoremap + +def build_detector(config): + block_dims = config['block_dims'] + return RDD_detector(block_dims, block_dims[-1]) \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/interpolator.py b/imcui/third_party/rdd/RDD/models/interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..608fdb176b1bbfd36a92b39b38133d8488da7553 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/interpolator.py @@ -0,0 +1,33 @@ +""" + "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." + https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/ +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class InterpolateSparse2d(nn.Module): + """ Efficiently interpolate tensor at given sparse 2D positions. """ + def __init__(self, mode = 'bilinear', align_corners = True): + super().__init__() + self.mode = mode + self.align_corners = align_corners + + def normgrid(self, x, H, W): + """ Normalize coords to [-1,1]. """ + return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1. + + def forward(self, x, pos, H, W): + """ + Input + x: [B, C, H, W] feature tensor + pos: [B, N, 2] tensor of positions + H, W: int, original resolution of input 2d positions -- used in normalization [-1,1] + + Returns + [B, N, C] sampled channels at 2d positions + """ + grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype) + x = F.grid_sample(x, grid, mode = self.mode , align_corners = self.align_corners) + return x.permute(0,2,3,1).squeeze(-2) \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/ops/functions/__init__.py b/imcui/third_party/rdd/RDD/models/ops/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b06b5ac538b63bdb9a6c82e4635b95bb5491d5b --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/functions/__init__.py @@ -0,0 +1,13 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py b/imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000000000000000000000000000000000000..13064ccacf430fe94ed8207fc70243c9b72aceee --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,72 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +try: + import MultiScaleDeformableAttention as MSDA +except ModuleNotFoundError as e: + info_string = ( + "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" + "\t`cd mask2former/modeling/pixel_decoder/ops`\n" + "\t`sh make.sh`\n" + ) + print(info_string) + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/imcui/third_party/rdd/RDD/models/ops/make.sh b/imcui/third_party/rdd/RDD/models/ops/make.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b38cdbf48f3571d986a33e7563b517952b51bb2 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/make.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +python setup.py build install diff --git a/imcui/third_party/rdd/RDD/models/ops/modules/__init__.py b/imcui/third_party/rdd/RDD/models/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6fdbf03359958f3d67ab00f879bf6b61a6c8f06a --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/modules/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +from .ms_deform_attn import MSDeformAttn diff --git a/imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py b/imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b4c42ea504a0859ccadd72646919c941e72f73 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py @@ -0,0 +1,125 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction +from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 128 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + try: + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + except: + # CPU + output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) + # # For FLOPs calculation only + # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) + output = self.output_proj(output) + return output diff --git a/imcui/third_party/rdd/RDD/models/ops/setup.py b/imcui/third_party/rdd/RDD/models/ops/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..3b57ad313ac8f9b6586892142da8ba943e516cec --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/setup.py @@ -0,0 +1,78 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + # Force cuda since torch ask for a device, not if cuda is in fact available. + if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + if CUDA_HOME is None: + raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') + else: + raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages(exclude=("configs", "tests",)), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48757e2b0156b2c1513b615d2a17e5aee5172ae7 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,46 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h b/imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..51bb27e9ee828f967e8aa854c2d55574040c6d7e --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,38 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu b/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..0c465dab3d636dfd6a44523c63f148b6e15084d9 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,158 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h b/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..4f0658e8668a11f0e7d71deff9adac71884f2e87 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,35 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c04e0d4ab97d25c1756fcd8d08dd1e5a6d280b7c --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1332 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h b/imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h new file mode 100644 index 0000000000000000000000000000000000000000..2f80a1b294c55b37d13bb3558ff7aeadba3b37de --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h @@ -0,0 +1,67 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/imcui/third_party/rdd/RDD/models/ops/src/vision.cpp b/imcui/third_party/rdd/RDD/models/ops/src/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a08821e0121a77556aa7a263ec8ebfa928b13b6 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/src/vision.cpp @@ -0,0 +1,21 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +/*! +* Copyright (c) Facebook, Inc. and its affiliates. +* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/imcui/third_party/rdd/RDD/models/ops/test.py b/imcui/third_party/rdd/RDD/models/ops/test.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1b545459f6fd3235767e721eb5a1090ae14bef --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/ops/test.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/imcui/third_party/rdd/RDD/models/position_encoding.py b/imcui/third_party/rdd/RDD/models/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8d8b1992004989cceec248be84530ccca41a2a --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/position_encoding.py @@ -0,0 +1,48 @@ +import math +import torch +from torch import nn +from ..utils.misc import NestedTensor + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def build_position_encoding(config): + N_steps = config['hidden_dim'] // 2 + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + return position_embedding \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/models/soft_detect.py b/imcui/third_party/rdd/RDD/models/soft_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..2472ec8ad6577c94e423bc651c6e87ef29bb4583 --- /dev/null +++ b/imcui/third_party/rdd/RDD/models/soft_detect.py @@ -0,0 +1,250 @@ +# ALIKE: https://github.com/Shiaoming/ALIKE +import torch +from torch import nn +import numpy as np +import torch.nn.functional as F + + +# coordinates system +# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ] +# | ----------------------------- +# | | | +# | | | +# | | | +# | | image | +# | | | +# | | | +# | | | +# | |---------------------------| +# v +# [ y: range=-1.0~1.0; h: range=0~H ] + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert (nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +""" + "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." + https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/ +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class InterpolateSparse2d(nn.Module): + """ Efficiently interpolate tensor at given sparse 2D positions. """ + def __init__(self, mode = 'bicubic', align_corners = False): + super().__init__() + self.mode = mode + self.align_corners = align_corners + + def normgrid(self, x, H, W): + """ Normalize coords to [-1,1]. """ + return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1. + + def forward(self, x, pos, H, W): + """ + Input + x: [B, C, H, W] feature tensor + pos: [B, N, 2] tensor of positions + H, W: int, original resolution of input 2d positions -- used in normalization [-1,1] + + Returns + [B, N, C] sampled channels at 2d positions + """ + grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype) + x = F.grid_sample(x, grid, mode = self.mode , align_corners = False) + return x.permute(0,2,3,1).squeeze(-2) + + +class SoftDetect(nn.Module): + def __init__(self, radius=2, top_k=0, scores_th=0.2, n_limit=20000): + """ + Args: + radius: soft detection radius, kernel size is (2 * radius + 1) + top_k: top_k > 0: return top k keypoints + scores_th: top_k <= 0 threshold mode: scores_th > 0: return keypoints with scores>scores_th + else: return keypoints with scores > scores.mean() + n_limit: max number of keypoint in threshold mode + """ + super().__init__() + self.radius = radius + self.top_k = top_k + self.scores_th = scores_th + self.n_limit = n_limit + self.kernel_size = 2 * self.radius + 1 + self.temperature = 0.1 # tuned temperature + self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius) + self.sample_descriptor = InterpolateSparse2d('bicubic') + # local xy grid + x = torch.linspace(-self.radius, self.radius, self.kernel_size) + # (kernel_size*kernel_size) x 2 : (w,h) + self.hw_grid = torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]] + + def detect_keypoints(self, scores_map, normalized_coordinates=True): + b, c, h, w = scores_map.shape + scores_nograd = scores_map.detach() + + # nms_scores = simple_nms(scores_nograd, self.radius) + nms_scores = simple_nms(scores_nograd, 2) + + # remove border + nms_scores[:, :, :self.radius + 1, :] = 0 + nms_scores[:, :, :, :self.radius + 1] = 0 + nms_scores[:, :, h - self.radius:, :] = 0 + nms_scores[:, :, :, w - self.radius:] = 0 + + # detect keypoints without grad + if self.top_k > 0: + topk = torch.topk(nms_scores.view(b, -1), self.top_k) + indices_keypoints = topk.indices # B x top_k + else: + if self.scores_th > 0: + masks = nms_scores > self.scores_th + if masks.sum() == 0: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + else: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + masks = masks.reshape(b, -1) + + indices_keypoints = [] # list, B x (any size) + scores_view = scores_nograd.reshape(b, -1) + for mask, scores in zip(masks, scores_view): + indices = mask.nonzero(as_tuple=False)[:, 0] + if len(indices) > self.n_limit: + kpts_sc = scores[indices] + sort_idx = kpts_sc.sort(descending=True)[1] + sel_idx = sort_idx[:self.n_limit] + indices = indices[sel_idx] + indices_keypoints.append(indices) + + # detect soft keypoints with grad backpropagation + patches = self.unfold(scores_map) # B x (kernel**2) x (H*W) + self.hw_grid = self.hw_grid.to(patches) # to device + keypoints = [] + scoredispersitys = [] + kptscores = [] + for b_idx in range(b): + patch = patches[b_idx].t() # (H*W) x (kernel**2) + indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M + patch_scores = patch[indices_kpt] # M x (kernel**2) + + # max is detached to prevent undesired backprop loops in the graph + max_v = patch_scores.max(dim=1).values.detach()[:, None] + x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1] + + # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} } + xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2 + + hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius, + dim=-1) ** 2 + scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1) + + # compute result keypoints + keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2 + keypoints_xy = keypoints_xy_nms + xy_residual + if normalized_coordinates: + keypoints_xy = keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1) + + kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), keypoints_xy.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN + + keypoints.append(keypoints_xy) + scoredispersitys.append(scoredispersity) + kptscores.append(kptscore) + + return keypoints, scoredispersitys, kptscores + + def forward(self, scores_map, normalized_coordinates=True): + """ + :param scores_map: Bx1xHxW + :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0 + """ + B, _, H, W = scores_map.shape + + keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map, + normalized_coordinates) + + # keypoints: B M 2 + # scoredispersitys: + return keypoints, kptscores, scoredispersitys + +import torch +import torch.nn as nn + +class Detect(nn.Module): + def __init__(self, stride=4, top_k=0, scores_th=0, n_limit=20000): + super().__init__() + self.stride = stride + self.top_k = top_k + self.scores_th = scores_th + self.n_limit = n_limit + + def forward(self, scores, coords, w, h): + """ + scores: B x N x 1 (keypoint confidence scores) + coords: B x N x 2 (offsets within stride x stride window) + w, h: Image dimensions + """ + b, n, _ = scores.shape + kpts_list = [] + scores_list = [] + + for b_idx in range(b): + score = scores[b_idx].squeeze(-1) # Shape: (N,) + coord = coords[b_idx] # Shape: (N, 2) + + # Apply score thresholding + if self.scores_th >= 0: + valid = score > self.scores_th + else: + valid = score > score.mean() + + valid_indices = valid.nonzero(as_tuple=True)[0] # Get valid indices + if valid_indices.numel() == 0: + kpts_list.append(torch.empty((0, 2), device=scores.device)) + scores_list.append(torch.empty((0,), device=scores.device)) + continue + + # Compute keypoint locations in original image space + i_ids = valid_indices # Indices where keypoints exist + kpts = torch.stack([i_ids % w, i_ids // w], dim=1).to(torch.float) * self.stride # Grid position + kpts += coord[i_ids] * self.stride # Apply offset + + # Normalize keypoints to [-1, 1] range + kpts = (kpts / torch.tensor([w - 1, h - 1], device=kpts.device, dtype=kpts.dtype)) * 2 - 1 + + # Filter top-k keypoints if needed + scores_valid = score[valid_indices] + if self.top_k > 0 and len(kpts) > self.top_k: + topk = torch.topk(scores_valid, self.top_k, dim=0) + kpts = kpts[topk.indices] + scores_valid = topk.values + elif self.top_k < 0: + if len(kpts) > self.n_limit: + sorted_idx = scores_valid.argsort(descending=True)[:self.n_limit] + kpts = kpts[sorted_idx] + scores_valid = scores_valid[sorted_idx] + + kpts_list.append(kpts) + scores_list.append(scores_valid) + + return kpts_list, scores_list \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/utils/__init__.py b/imcui/third_party/rdd/RDD/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45d3b66b8857cbbd8e8bce38dc2117ada1bc536c --- /dev/null +++ b/imcui/third_party/rdd/RDD/utils/__init__.py @@ -0,0 +1 @@ +from .misc import * \ No newline at end of file diff --git a/imcui/third_party/rdd/RDD/utils/misc.py b/imcui/third_party/rdd/RDD/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ee94bc4d066a11bb475ae465dc9351390e2c92 --- /dev/null +++ b/imcui/third_party/rdd/RDD/utils/misc.py @@ -0,0 +1,531 @@ +# Modified from Deformable DETR +# https://github.com/fundamentalvision/Deformable-DETR +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List +import yaml +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__.split('.')[1]) < 0.5: + import math + from torchvision.ops.misc import _NewEmptyTensorOp + def _check_size_scale_factor(dim, size, scale_factor): + # type: (int, Optional[List[int]], Optional[float]) -> None + if size is None and scale_factor is None: + raise ValueError("either size or scale_factor should be defined") + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + if not (scale_factor is not None and len(scale_factor) != dim): + raise ValueError( + "scale_factor shape must match input shape. " + "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) + ) + def _output_size(dim, input, size, scale_factor): + # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] + assert dim == 2 + _check_size_scale_factor(dim, size, scale_factor) + if size is not None: + return size + # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat + assert scale_factor is not None and isinstance(scale_factor, (int, float)) + scale_factors = [scale_factor, scale_factor] + # math.floor might return float in py2.7 + return [ + int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) + ] +elif float(torchvision.__version__.split('.')[1]) < 7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device, non_blocking=False): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device, non_blocking=non_blocking) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device, non_blocking=non_blocking) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def record_stream(self, *args, **kwargs): + self.tensors.record_stream(*args, **kwargs) + if self.mask is not None: + self.mask.record_stream(*args, **kwargs) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_size(): + if not is_dist_avail_and_initialized(): + return 1 + return int(os.environ['LOCAL_SIZE']) + + +def get_local_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return int(os.environ['LOCAL_RANK']) + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + args.dist_url = 'env://' + os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) + elif 'SLURM_PROCID' in os.environ: + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + 'scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['LOCAL_SIZE'] = str(num_gpus) + args.dist_url = 'env://' + args.world_size = ntasks + args.rank = proc_id + args.gpu = proc_id % num_gpus + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + if float(torchvision.__version__[:3]) < 0.5: + return _NewEmptyTensorOp.apply(input, output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +def get_total_grad_norm(parameters, norm_type=2): + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + device = parameters[0].grad.device + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), + norm_type) + return total_norm + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + + +def to_pixel_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return flow + +def read_config(file_path): + with open(file_path, 'r') as file: + config = yaml.safe_load(file) + return config \ No newline at end of file diff --git a/imcui/third_party/rdd/README.md b/imcui/third_party/rdd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..776d6bc0eca402e939d06c59ea274ebb65e0b63f --- /dev/null +++ b/imcui/third_party/rdd/README.md @@ -0,0 +1,197 @@ +## RDD: Robust Feature Detector and Descriptor using Deformable Transformer (CVPR 2025) +[Gonglin Chen](https://xtcpete.com/) · [Tianwen Fu](https://twfu.me/) · [Haiwei Chen](https://scholar.google.com/citations?user=LVWRssoAAAAJ&hl=en) · [Wenbin Teng](https://wbteng9526.github.io/) · [Hanyuan Xiao](https://corneliushsiao.github.io/index.html) · [Yajie Zhao](https://ict.usc.edu/about-us/leadership/research-leadership/yajie-zhao/) + +[Project Page](https://xtcpete.github.io/rdd/) + +## Table of Contents +- [Updates](#updates) +- [Installation](#installation) +- [Usage](#usage) + - [Inference](#inference) + - [Evaluation](#evaluation) + - [Training](#training) +- [Citation](#citation) +- [License](#license) +- [Acknowledgements](#acknowledgements) + +## Updates + +- SfM reconstruction through [COLMAP](https://github.com/colmap/colmap.git) added. We provide a ready-to-use [notebook](./demo_sfm.ipynb) for a simple example. Code adopted from [hloc](https://github.com/cvg/Hierarchical-Localization.git). + +- Training code and new weights released. + +- We have updated the training code compared to what was described in the paper. In the original setup, the RDD was trained on the MegaDepth and Air-to-Ground datasets by resizing all images to the training resolution. In this release, we retrained RDD on MegaDepth only, using a combination of resizing and cropping, a strategy used by [ALIKE](https://github.com/Shiaoming/ALIKE). This change significantly improves robustness. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
MegaDepth-1500MegaDepth-ViewAir-to-Ground
AUC 5°AUC 10°AUC 20°AUC 5°AUC 10°AUC 20°AUC 5°AUC 10°AUC 20°
RDD-v252.468.580.152.067.178.245.858.671.0
RDD-v148.265.278.338.353.165.641.456.067.8
RDD-v2+LG53.369.882.059.074.284.054.869.079.1
RDD-v1+LG52.368.981.854.269.380.355.168.978.9
+ +## Installation + +```bash +git clone --recursive https://github.com/xtcpete/rdd +cd RDD + +# Create conda env +conda create -n rdd python=3.10 pip +conda activate rdd + +# Install CUDA +conda install -c nvidia/label/cuda-11.8.0 cuda-toolkit +# Install torch +pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118 +# Install all dependencies +pip install -r requirements.txt +# Compile custom operations +cd ./RDD/models/ops +pip install -e . +``` + +We provide the [download link](https://drive.google.com/drive/folders/1QgVaqm4iTUCqbWb7_Fi6mX09EHTId0oA?usp=sharing) to: + - the MegaDepth-1500 test set + - the MegaDepth-View test set + - the Air-to-Ground test set + - 2 pretrained models, RDD and LightGlue for matching RDD + +Create and unzip downloaded test data to the `data` folder. + +Create and add weights to the `weights` folder and you are ready to go. + +## Usage +For your convenience, we provide a ready-to-use [notebook](./demo_matching.ipynb) for some examples. + +### Inference + +```python +from RDD.RDD import build + +RDD_model = build() + +output = RDD_model.extract(torch.randn(1, 3, 480, 640)) +``` + +### Evaluation + +Please note that due to the different GPU architectures and the stochastic nature of RANSAC, you may observe slightly different results; however, they should be very close to those reported in the paper. To reproduce the number in paper, use v1 weights instead. + +Results can be visualized by passing argument --plot + +**MegaDepth-1500** + +```bash +# Sparse matching +python ./benchmarks/mega_1500.py + +# Dense matching +python ./benchmarks/mega_1500.py --method dense + +# LightGlue +python ./benchmarks/mega_1500.py --method lightglue +``` + +**MegaDepth-View** + +```bash +# Sparse matching +python ./benchmarks/mega_view.py + +# Dense matching +python ./benchmarks/mega_view.py --method dense + +# LightGlue +python ./benchmarks/mega_view.py --method lightglue +``` + +**Air-to-Ground** + +```bash +# Sparse matching +python ./benchmarks/air_ground.py + +# Dense matching +python ./benchmarks/air_ground.py --method dense + +# LightGlue +python ./benchmarks/air_ground.py --method lightglue +``` + +### Training + +1. Download MegaDepth dataset using [download.sh](./data/megadepth/download.sh) and megadepth_indices from [LoFTR](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets). Then the MegaDepth root folder should look like the following: +```bash +./data/megadepth/megadepth_indices # indices +./data/megadepth/depth_undistorted # depth maps +./data/megadepth/Undistorted_SfM # images and poses +./data/megadepth/scene_info # indices for training LightGlue +``` +2. Then you can train RDD in two steps; Descriptor first +```bash +# distributed training with 8 gpus +python -m training.train --ckpt_save_path ./ckpt_descriptor --distributed --batch_size 32 + +# single gpu +python -m training.train --ckpt_save_path ./ckpt_descriptor +``` +and then Detector +```bash +python -m training.train --ckpt_save_path ./ckpt_detector --weights ./ckpt_descriptor/RDD_best.pth --train_detector --training_res 480 +``` + +I am working on recollecting the Air-to-Ground dataset because of licensing issues. + +## Citation +``` +@inproceedings{gonglin2025rdd, + title = {RDD: Robust Feature Detector and Descriptor using Deformable Transformer}, + author = {Chen, Gonglin and Fu, Tianwen and Chen, Haiwei and Teng, Wenbin and Xiao, Hanyuan and Zhao, Yajie}, + booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2025} +} +``` + + +## License +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) + +## Acknowledgements + +We thank these great repositories: [ALIKE](https://github.com/Shiaoming/ALIKE), [LoFTR](https://github.com/zju3dv/LoFTR), [DeDoDe](https://github.com/Parskatt/DeDoDe), [XFeat](https://github.com/verlab/accelerated_features), [LightGlue](https://github.com/cvg/LightGlue), [Kornia](https://github.com/kornia/kornia), and [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR), and many other inspiring works in the community. + +LightGlue is trained with [Glue Factory](https://github.com/cvg/glue-factory). + +Supported by the Intelligence Advanced Research Projects Activity (IARPA) via Department of Interior/Interior Business Center (DOI/IBC) contract number 140D0423C0075. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes notwithstanding any copyright annotation thereon. Disclaimer: The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of IARPA, DOI/IBC, or the U.S. Government. We would like to thank Yayue Chen for her help with visualization. diff --git a/imcui/third_party/rdd/benchmarks/air_ground.py b/imcui/third_party/rdd/benchmarks/air_ground.py new file mode 100644 index 0000000000000000000000000000000000000000..fb129ccdad131de66a3e3ac9e5fa86c9607816f9 --- /dev/null +++ b/imcui/third_party/rdd/benchmarks/air_ground.py @@ -0,0 +1,247 @@ +import sys +sys.path.append(".") +import numpy as np +import torch +from PIL import Image +import tqdm +import cv2 +import argparse +import matplotlib.pyplot as plt +import matplotlib +from RDD.RDD_helper import RDD_helper +from RDD.RDD import build +import os +from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) + +def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None): + conf_thr = 1e-4 + + img0 = np.array(img0) + img1 = np.array(img1) + + kpts0 = kpts0 + kpts1 = kpts1 + + epi_errs = epi_errs.cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'e_t: {e_t:.2f} | e_R: {e_R:.2f}', + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text, path=path) + return figure + +class AirGroundPoseMNNBenchmark: + def __init__(self, data_root="./data/air_ground", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "indices.npz", + ] + # self.scene_names = ["0022_0.5_0.7.npz",] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=10, plot=False, method='sparse'): + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + indices = scene['pair_info'] + idx = 0 + for pair in tqdm.tqdm(indices): + + pairs = pair['pair_names'] + K0 = pair['intrinsic'][0].copy().astype(np.float32) + T0 = pair['pose'][0].copy().astype(np.float32) + R0, t0 = T0[:3, :3], T0[:3, 3] + K1 = pair['intrinsic'][1].copy().astype(np.float32) + T1 = pair['pose'][1].copy().astype(np.float32) + R1, t1 = T1[:3, :3], T1[:3, 3] + R, t = compute_relative_pose(R0, t0, R1, t1) + T0_to_1 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/images/{pairs[0]}" + im_B_path = f"{data_root}/images/{pairs[1]}" + + im_A = cv2.imread(im_A_path) + im_B = cv2.imread(im_B_path) + + if method == 'dense': + kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600) + elif method == 'lightglue': + kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600) + elif method == 'sparse': + kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600) + else: + raise ValueError(f"Invalid method {method}") + + im_A = Image.open(im_A_path) + w0, h0 = im_A.size + im_B = Image.open(im_B_path) + w1, h1 = im_B.size + if scale_intrinsics: + scale0 = 840 / max(w0, h0) + scale1 = 840 / max(w1, h1) + w0, h0 = scale0 * w0, scale0 * h0 + w1, h1 = scale1 * w1, scale1 * h1 + K0, K1 = K0.copy(), K1.copy() + K0[:2] = K0[:2] * scale0 + K1[:2] = K1[:2] * scale1 + + threshold = 0.5 + if calibrated: + norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2]))) + ret = estimate_pose( + kpts0, + kpts1, + K0, + K1, + norm_threshold, + conf=0.99999, + ) + if ret is not None: + R_est, t_est, mask = ret + T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) # + T0_to_1 = np.concatenate((R, t[:,None]), axis=-1) + e_t, e_R = compute_pose_error(T0_to_1_est, R, t) + + epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1) + if scene_ind % plot_every_iter == 0 and plot: + + if not os.path.exists(f'outputs/air_ground/{model_name}_{method}'): + os.mkdir(f'outputs/air_ground/{model_name}_{method}') + name = f'outputs/air_ground/{model_name}_{method}/{scene_name}_{idx}.png' + _make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name) + e_pose = max(e_t, e_R) + + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + idx += 1 + + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } + + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Testing script.") + + parser.add_argument("--data_root", type=str, default="./data/air_ground", help="Path to the Air-to-Ground test dataset.") + + parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.") + + parser.add_argument("--plot", action="store_true", help="Whether to plot the results.") + + parser.add_argument("--method", type=str, default="sparse", help="Method for matching.") + + return parser.parse_args() + +if __name__ == "__main__": + args = parse_arguments() + + if not os.path.exists('outputs'): + os.mkdir('outputs') + if not os.path.exists(f'outputs/air_ground'): + os.mkdir(f'outputs/air_ground') + model = build(weights=args.weights) + benchmark = AirGroundPoseMNNBenchmark(data_root=args.data_root) + model.eval() + model_helper = RDD_helper(model) + with torch.no_grad(): + method = args.method + out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method) + with open(f'outputs/air_ground/RDD_{method}.txt', 'w') as f: + f.write(str(out)) \ No newline at end of file diff --git a/imcui/third_party/rdd/benchmarks/mega_1500.py b/imcui/third_party/rdd/benchmarks/mega_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..7ccec9873d02bc34f51c8703694d679a74f61b99 --- /dev/null +++ b/imcui/third_party/rdd/benchmarks/mega_1500.py @@ -0,0 +1,255 @@ +import sys +sys.path.append(".") +import numpy as np +import torch +from PIL import Image +import tqdm +import cv2 +import argparse +from RDD.RDD_helper import RDD_helper +from RDD.RDD import build +import matplotlib.pyplot as plt +import matplotlib +import os +from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) + +def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None): + conf_thr = 1e-4 + + img0 = np.array(img0) + img1 = np.array(img1) + + kpts0 = kpts0 + kpts1 = kpts1 + + epi_errs = epi_errs.cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'e_t: {e_t:.2f} | e_R: {e_R:.2f}', + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text, path=path) + return figure + +class MegaDepthPoseMNNBenchmark: + def __init__(self, data_root="./megadepth_test_1500", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=1, plot=False, method='sparse'): + + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + print(f"Processing {scene_name}") + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm.tqdm(pair_inds): + idx0, idx1 = pairs[pairind][0] + K0 = intrinsics[idx0].copy() + T0 = poses[idx0].copy() + R0, t0 = T0[:3, :3], T0[:3, 3] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + R, t = compute_relative_pose(R0, t0, R1, t1) + T0_to_1 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/{im_paths[idx0]}" + im_B_path = f"{data_root}/{im_paths[idx1]}" + + im_A = cv2.imread(im_A_path) + im_B = cv2.imread(im_B_path) + + if method == 'dense': + kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600) + elif method == 'lightglue': + kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600) + elif method == 'sparse': + kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600) + else: + kpts0, kpts1, conf = model_helper.match_3rd_party(im_A, im_B, thr=0.01, resize=1600, model=method) + + im_A = Image.open(im_A_path) + w0, h0 = im_A.size + im_B = Image.open(im_B_path) + w1, h1 = im_B.size + if scale_intrinsics: + scale0 = 840 / max(w0, h0) + scale1 = 840 / max(w1, h1) + w0, h0 = scale0 * w0, scale0 * h0 + w1, h1 = scale1 * w1, scale1 * h1 + K0, K1 = K0.copy(), K1.copy() + K0[:2] = K0[:2] * scale0 + K1[:2] = K1[:2] * scale1 + + + threshold = 0.5 + if calibrated: + norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2]))) + ret = estimate_pose( + kpts0, + kpts1, + K0, + K1, + norm_threshold, + conf=0.99999, + ) + if ret is not None: + R_est, t_est, mask = ret + T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) # + T0_to_1 = np.concatenate((R, t[:,None]), axis=-1) + e_t, e_R = compute_pose_error(T0_to_1_est, R, t) + + epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1) + if scene_ind % plot_every_iter == 0 and plot: + + if not os.path.exists(f'outputs/mega_1500/{model_name}_{method}'): + os.mkdir(f'outputs/mega_1500/{model_name}_{method}') + name = f'outputs/mega_1500/{model_name}_{method}/{scene_name}_{pairind}.png' + _make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name) + e_pose = max(e_t, e_R) + + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Testing script.") + + parser.add_argument("--data_root", type=str, default="./data/megadepth_test_1500", help="Path to the MegaDepth dataset.") + + parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.") + + parser.add_argument("--plot", action="store_true", help="Whether to plot the results.") + + parser.add_argument("--method", type=str, default="sparse", help="Method for matching.") + + return parser.parse_args() + +if __name__ == "__main__": + args = parse_arguments() + if not os.path.exists('outputs'): + os.mkdir('outputs') + + if not os.path.exists(f'outputs/mega_1500'): + os.mkdir(f'outputs/mega_1500') + + model = build(weights=args.weights) + benchmark = MegaDepthPoseMNNBenchmark(data_root=args.data_root) + model.eval() + model_helper = RDD_helper(model) + with torch.no_grad(): + method = args.method + out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method) + with open(f'outputs/mega_1500/RDD_{method}.txt', 'w') as f: + f.write(str(out)) \ No newline at end of file diff --git a/imcui/third_party/rdd/benchmarks/mega_view.py b/imcui/third_party/rdd/benchmarks/mega_view.py new file mode 100644 index 0000000000000000000000000000000000000000..feef900e87f0670cd164b1dac0fe595e938921a1 --- /dev/null +++ b/imcui/third_party/rdd/benchmarks/mega_view.py @@ -0,0 +1,250 @@ +import sys +sys.path.append(".") +import numpy as np +import torch +from PIL import Image +import tqdm +import cv2 +import argparse +import matplotlib.pyplot as plt +import matplotlib +from RDD.RDD import build +from RDD.RDD_helper import RDD_helper +import os +from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) + +def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None): + conf_thr = 1e-4 + + img0 = np.array(img0) + img1 = np.array(img1) + + kpts0 = kpts0 + kpts1 = kpts1 + + epi_errs = epi_errs.cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'e_t: {e_t:.2f} | e_R: {e_R:.2f}', + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text, path=path) + return figure + +class MegaDepthPoseMNNBenchmark: + def __init__(self, data_root="./megadepth_test_1500", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "hard_indices.npz", + ] + # self.scene_names = ["0022_0.5_0.7.npz",] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=1, plot=False, method='sparse'): + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + indices = scene['indices'] + idx = 0 + + for pair in tqdm.tqdm(indices): + + pairs = pair['pair_names'] + K0 = pair['intrisinic'][0].copy().astype(np.float32) + T0 = pair['pose'][0].copy().astype(np.float32) + R0, t0 = T0[:3, :3], T0[:3, 3] + K1 = pair['intrisinic'][1].copy().astype(np.float32) + T1 = pair['pose'][1].copy().astype(np.float32) + R1, t1 = T1[:3, :3], T1[:3, 3] + R, t = compute_relative_pose(R0, t0, R1, t1) + T0_to_1 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/images/{pairs[0]}" + im_B_path = f"{data_root}/images/{pairs[1]}" + + im_A = cv2.imread(im_A_path) + im_B = cv2.imread(im_B_path) + + if method == 'dense': + kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600) + elif method == 'lightglue': + kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600) + elif method == 'sparse': + kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600) + else: + raise ValueError(f"Invalid method {method}") + + im_A = Image.open(im_A_path) + w0, h0 = im_A.size + im_B = Image.open(im_B_path) + w1, h1 = im_B.size + + if scale_intrinsics: + scale0 = 840 / max(w0, h0) + scale1 = 840 / max(w1, h1) + w0, h0 = scale0 * w0, scale0 * h0 + w1, h1 = scale1 * w1, scale1 * h1 + K0, K1 = K0.copy(), K1.copy() + K0[:2] = K0[:2] * scale0 + K1[:2] = K1[:2] * scale1 + + threshold = 0.5 + if calibrated: + norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2]))) + ret = estimate_pose( + kpts0, + kpts1, + K0, + K1, + norm_threshold, + conf=0.99999, + ) + if ret is not None: + R_est, t_est, mask = ret + T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) # + T0_to_1 = np.concatenate((R, t[:,None]), axis=-1) + e_t, e_R = compute_pose_error(T0_to_1_est, R, t) + + epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1) + if scene_ind % plot_every_iter == 0 and plot: + + if not os.path.exists(f'outputs/mega_view/{model_name}_{method}'): + os.mkdir(f'outputs/mega_view/{model_name}_{method}') + name = f'outputs/mega_view/{model_name}_{method}/{scene_name}_{idx}.png' + _make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name) + e_pose = max(e_t, e_R) + + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + idx += 1 + + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } + + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Testing script.") + + parser.add_argument("--data_root", type=str, default="./data/megadepth_view", help="Path to the MegaDepth dataset.") + + parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.") + + parser.add_argument("--plot", action="store_true", help="Whether to plot the results.") + + parser.add_argument("--method", type=str, default="sparse", help="Method for matching.") + + return parser.parse_args() + +if __name__ == "__main__": + args = parse_arguments() + if not os.path.exists('outputs'): + os.mkdir('outputs') + + if not os.path.exists(f'outputs/mega_view'): + os.mkdir(f'outputs/mega_view') + model = build(weights=args.weights) + benchmark = MegaDepthPoseMNNBenchmark(data_root=args.data_root) + model.eval() + model_helper = RDD_helper(model) + with torch.no_grad(): + method = args.method + out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method) + with open(f'outputs/mega_view/RDD_{method}.txt', 'w') as f: + f.write(str(out)) + + diff --git a/imcui/third_party/rdd/benchmarks/utils.py b/imcui/third_party/rdd/benchmarks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f72b7df20c7750e0cdb0812b4a3bde600433b2ca --- /dev/null +++ b/imcui/third_party/rdd/benchmarks/utils.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +from kornia.geometry.epipolar import numeric +from kornia.geometry.conversions import convert_points_to_homogeneous +import cv2 + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + +def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. + Args: + pts0 (torch.Tensor): [N, 2] + E (torch.Tensor): [3, 3] + """ + pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + pts0 = convert_points_to_homogeneous(pts0) + pts1 = convert_points_to_homogeneous(pts1) + + Ep0 = pts0 @ E.T # [N, 3] + p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] + Etp1 = pts1 @ E # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + return d + +def compute_symmetrical_epipolar_errors(T_0to1, pts0, pts1, K0, K1, device='cuda'): + """ + Update: + data (dict):{"epi_errs": [M]} + """ + pts0 = torch.tensor(pts0, device=device) + pts1 = torch.tensor(pts1, device=device) + K0 = torch.tensor(K0, device=device) + K1 = torch.tensor(K1, device=device) + T_0to1 = torch.tensor(T_0to1, device=device) + Tx = numeric.cross_product_matrix(T_0to1[:3, 3]) + E_mat = Tx @ T_0to1[:3, :3] + + epi_err = symmetric_epipolar_distance(pts0, pts1, E_mat, K0, K1) + return epi_err + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans + +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) \ No newline at end of file diff --git a/imcui/third_party/rdd/configs/default.yaml b/imcui/third_party/rdd/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e740015c77e6040a7e637ca27a15c6a6c7caf0f --- /dev/null +++ b/imcui/third_party/rdd/configs/default.yaml @@ -0,0 +1,19 @@ +activation: relu +block_dims: +- 8 +- 16 +- 32 +- 64 +d_model: 256 +detection_threshold: 0.1 +device: cuda +dim_feedforward: 1024 +dropout: 0.1 +enc_n_points: 8 +hidden_dim: 256 +lr_backbone: 2.0e-05 +nhead: 8 +num_encoder_layers: 4 +num_feature_levels: 5 +top_k: 4096 +train_detector: False \ No newline at end of file diff --git a/imcui/third_party/rdd/data/megadepth/download.sh b/imcui/third_party/rdd/data/megadepth/download.sh new file mode 100644 index 0000000000000000000000000000000000000000..d116020a5a477a65b366990a62a0cb34f7fc0660 --- /dev/null +++ b/imcui/third_party/rdd/data/megadepth/download.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Base URL for the files +url_base="https://cvg-data.inf.ethz.ch/megadepth/" + +# Define the tar files and their corresponding destination directories +declare -A file_mappings=( + ["Undistorted_SfM.tar.gz"]="Undistorted_SfM/" + ["depth_undistorted.tar.gz"]="depth_undistorted/" + ["scene_info.tar.gz"]="scene_info/" +) + +# Download, extract, and move files +for tar_name in "${!file_mappings[@]}"; do + echo "downloading ${tar_name}" + + out_name="${file_mappings[$tar_name]}" + + # Full path of the tar.gz file + tar_path="./${tar_name}" + + # Download the file + wget "${url_base}${tar_name}" -O "$tar_path" + + # Check if download was successful + if [ $? -ne 0 ]; then + echo "Failed to download $tar_name" + exit 1 + fi + + # Extract the tar.gz file + tar -xzf "$tar_path" -C "./" + + # Remove the tar.gz file after extraction + rm "$tar_path" + + # Move the extracted folder to the desired location + extracted_folder="${tar_name%%.*}" # Get the folder name (remove .tar.gz) + mv "$tmp_dir/$extracted_folder" "$tmp_dir/$out_name" +done \ No newline at end of file diff --git a/imcui/third_party/rdd/demo_matching.ipynb b/imcui/third_party/rdd/demo_matching.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..51c77aacc2ac71cc378298194b6a7a9439297e55 --- /dev/null +++ b/imcui/third_party/rdd/demo_matching.ipynb @@ -0,0 +1,230 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GzFYZYcT9oyb" + }, + "source": [ + "# RDD matching example (sparse, semi-dense and lightglue)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "97Mbt4a89z3Z" + }, + "source": [ + "## Initialize RDD" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from RDD.RDD import build\n", + "from RDD.RDD_helper import RDD_helper\n", + "from matplotlib import pyplot as plt\n", + "from time import time\n", + "\n", + "RDD_model = build(weights='./weights/RDD-v2.pth')\n", + "RDD_model.eval()\n", + "RDD = RDD_helper(RDD_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import cv2\n", + "\n", + "def draw_matches(ref_points, dst_points, img0, img1):\n", + " \n", + " # Prepare keypoints and matches for drawMatches function\n", + " keypoints0 = [cv2.KeyPoint(p[0], p[1], 1000) for p in ref_points]\n", + " keypoints1 = [cv2.KeyPoint(p[0], p[1], 1000) for p in dst_points]\n", + " matches = [cv2.DMatch(i,i,0) for i in range(len(ref_points))]\n", + "\n", + " # Draw inlier matches\n", + " img_matches = cv2.drawMatches(img0, keypoints0, img1, keypoints1, matches, None,\n", + " matchColor=(0, 255, 0), flags=2)\n", + "\n", + " return img_matches\n", + "\n", + "\n", + "def draw_points(points, img):\n", + " for p in points:\n", + " cv2.circle(img, (int(p[0]), int(p[1])), 2, (0, 255, 0), -1)\n", + " \n", + " return img\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b83vE-Dt-cTC" + }, + "source": [ + "## Matching example - Sparse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Load some example images\n", + "im0 = cv2.imread('./assets/image0.jpg')\n", + "im1 = cv2.imread('./assets/image1.jpg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 431 + }, + "id": "8qm_cdIq9-jy", + "outputId": "ebd99a35-807d-4684-f43b-4f1b0a022c66" + }, + "outputs": [], + "source": [ + "start = time()\n", + "mkpts_0, mkpts_1, conf = RDD.match(im0, im1, resize=1024)\n", + "print(f\"Found {len(mkpts_0)} matches in {time()-start:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "canvas = draw_matches(mkpts_0, mkpts_1, im0, im1)\n", + "plt.figure(figsize=(12,12))\n", + "plt.imshow(canvas[..., ::-1]), plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Matching example - Semi-Dense" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start = time()\n", + "mkpts_0, mkpts_1, conf = RDD.match_dense(im0, im1, resize=1024)\n", + "print(f\"Found {len(mkpts_0)} matches in {time()-start:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "canvas = draw_matches(mkpts_0, mkpts_1, im0, im1)\n", + "plt.figure(figsize=(12,12))\n", + "plt.imshow(canvas[..., ::-1]), plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Matching example - LightGlue" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start = time()\n", + "mkpts_0, mkpts_1, conf = RDD.match_lg(im0, im1, resize=1024)\n", + "print(f\"Found {len(mkpts_0)} matches in {time()-start:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "canvas = draw_matches(mkpts_0, mkpts_1, im0, im1)\n", + "plt.figure(figsize=(12,12))\n", + "plt.imshow(canvas[..., ::-1]), plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Matching example - Using differnt detector + RDD descriptor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start = time()\n", + "mkpts_0, mkpts_1, conf = RDD.match_3rd_party(im0, im1, resize=1024, model='aliked')\n", + "print(f\"Found {len(mkpts_0)} matches in {time()-start:.2f} seconds\")\n", + "\n", + "# take a look at folder third_party, RDD/RDD.py and RDD/RDD_helper.py \n", + "# if you want to configure your own detector" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "KM1KQaj9-oOv" + ], + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "RDD", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/imcui/third_party/rdd/demo_sfm.ipynb b/imcui/third_party/rdd/demo_sfm.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..32a5d0035df3ffea9ba24ec84afa8c99d9b51b15 --- /dev/null +++ b/imcui/third_party/rdd/demo_sfm.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# RDD reconstruction example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sfm import match_rdd, extract_rdd\n", + "from hloc import (\n", + " extract_features,\n", + " reconstruction,\n", + " visualization,\n", + " pairs_from_retrieval,\n", + " pairs_from_exhaustive,\n", + ")\n", + "from pathlib import Path\n", + "import os\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "images_dir = Path('./assets/mapping')\n", + "device = torch.cuda.is_available()\n", + "images = [image for image in os.listdir(images_dir) if image.endswith('.jpg') or image.endswith('.png')]\n", + "outputs = Path('./outputs/reconstruction')\n", + "if not outputs.exists():\n", + " outputs.mkdir(parents=True)\n", + "sfm_pairs = outputs / 'sfm_pairs.txt'\n", + "retrieval_conf = extract_features.confs[\"netvlad\"]\n", + "feature_conf = extract_rdd.confs[\"rdd\"]\n", + "matcher_conf = match_rdd.confs[\"rdd+lightglue\"]\n", + "exhaustive_if_less = 30\n", + "num_matched = 20" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# image_retrieval\n", + "if len(images) < exhaustive_if_less:\n", + " pairs_from_exhaustive.main(sfm_pairs, images)\n", + "else:\n", + " retrieval_path = extract_features.main(retrieval_conf, images_dir, outputs)\n", + " pairs_from_retrieval.main(retrieval_path, sfm_pairs, num_matched=num_matched)\n", + " \n", + "# feature_extraction\n", + "feature_path = extract_rdd.main(feature_conf, images_dir, outputs)\n", + "# matching\n", + "match_path = match_rdd.main(matcher_conf, sfm_pairs, feature_conf['output'], outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# reconstruction\n", + "image_options = {}\n", + "mapper_options = {}\n", + "model = reconstruction.main(outputs, images_dir, sfm_pairs, feature_path, \n", + " match_path, verbose=True, camera_mode='PER_IMAGE', image_options=image_options, mapper_options=mapper_options,\n", + " min_match_score = 0.2, skip_geometric_verification=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(model.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RDD", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/imcui/third_party/rdd/requirements.txt b/imcui/third_party/rdd/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3a9a071347d0dde044f5b0a166438c607a88f448 --- /dev/null +++ b/imcui/third_party/rdd/requirements.txt @@ -0,0 +1,8 @@ +opencv_python +ninja +poselib +tqdm +kornia +matplotlib +pyyaml +git+https://github.com/cvg/Hierarchical-Localization.git \ No newline at end of file diff --git a/imcui/third_party/rdd/sfm/extract_rdd.py b/imcui/third_party/rdd/sfm/extract_rdd.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba3f6fc789507559d412b73cb47329a3549d88c --- /dev/null +++ b/imcui/third_party/rdd/sfm/extract_rdd.py @@ -0,0 +1,145 @@ +""" +Modified from hloc +https://github.com/cvg/Hierarchical-Localization.git +""" +import argparse +import collections.abc as collections +import glob +import pprint +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, List, Optional, Union +import cv2 +import h5py +import numpy as np +import PIL.Image +import torch +from tqdm import tqdm +from hloc.extract_features import ImageDataset +from hloc import logger +from hloc.utils.base_model import dynamic_load +from hloc.utils.io import list_h5_names, read_image +from hloc.utils.parsers import parse_image_lists +from RDD.RDD import build +from RDD.utils import read_config + +confs = { + 'rdd': { + "output": "feats-rdd-n4096", + "model": { + 'config_path': './configs/default.yaml', + 'weights': './weights/RDD-v2.pth', + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + "resize_force": True, + } + } +} + +@torch.no_grad() +def main( + conf: Dict, + image_dir: Path, + export_dir: Optional[Path] = None, + as_half: bool = True, + image_list: Optional[Union[Path, List[str]]] = None, + feature_path: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + logger.info( + "Extracting local features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + dataset = ImageDataset(image_dir, conf["preprocessing"], image_list) + if feature_path is None: + feature_path = Path(export_dir, conf["output"] + ".h5") + feature_path.parent.mkdir(exist_ok=True, parents=True) + skip_names = set( + list_h5_names(feature_path) if feature_path.exists() and not overwrite else () + ) + dataset.names = [n for n in dataset.names if n not in skip_names] + if len(dataset.names) == 0: + logger.info("Skipping the extraction.") + return feature_path + + device = "cuda" if torch.cuda.is_available() else "cpu" + config = read_config(conf["model"]["config_path"]) + config['device'] = device + model = build(config, conf["model"]["weights"]) + model.eval() + loader = torch.utils.data.DataLoader( + dataset, num_workers=1, shuffle=False, pin_memory=True + ) + for idx, data in enumerate(tqdm(loader)): + name = dataset.names[idx] + features = model.extract(data["image"]) + + pred = { + "keypoints": [f["keypoints"] for f in features], + "keypoint_scores": [f["scores"] for f in features], + "descriptors": [f["descriptors"].t() for f in features], + } + + pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + + pred["image_size"] = original_size = data["original_size"][0].numpy() + if "keypoints" in pred: + size = np.array(data["image"].shape[-2:][::-1]) + scales = (original_size / size).astype(np.float32) + pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5 + if "scales" in pred: + pred["scales"] *= scales.mean() + # add keypoint uncertainties scaled to the original resolution + uncertainty = getattr(model, "detection_noise", 1) * scales.mean() + + if as_half: + for k in pred: + dt = pred[k].dtype + if (dt == np.float32) and (dt != np.float16): + pred[k] = pred[k].astype(np.float16) + + with h5py.File(str(feature_path), "a", libver="latest") as fd: + try: + if name in fd: + del fd[name] + grp = fd.create_group(name) + for k, v in pred.items(): + grp.create_dataset(k, data=v) + if "keypoints" in pred: + grp["keypoints"].attrs["uncertainty"] = uncertainty + except OSError as error: + if "No space left on device" in error.args[0]: + logger.error( + "Out of disk space: storing features on disk can take " + "significant space, did you enable the as_half flag?" + ) + del grp, fd[name] + raise error + + del pred + + logger.info("Finished exporting features.") + return feature_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image_dir", type=Path, required=True) + parser.add_argument("--export_dir", type=Path, required=True) + parser.add_argument( + "--conf", type=str, default="rdd", choices=list(confs.keys()) + ) + parser.add_argument("--as_half", action="store_true") + parser.add_argument("--image_list", type=Path) + parser.add_argument("--feature_path", type=Path) + args = parser.parse_args() + main( + confs[args.conf], + args.image_dir, + args.export_dir, + args.as_half, + args.image_list, + args.feature_path, + ) \ No newline at end of file diff --git a/imcui/third_party/rdd/sfm/match_rdd.py b/imcui/third_party/rdd/sfm/match_rdd.py new file mode 100644 index 0000000000000000000000000000000000000000..c38924829428cbd341667701ffec0dc83c06bd0b --- /dev/null +++ b/imcui/third_party/rdd/sfm/match_rdd.py @@ -0,0 +1,255 @@ +import argparse +import pprint +from functools import partial +from pathlib import Path +from queue import Queue +from threading import Thread +from typing import Dict, List, Optional, Tuple, Union +from torch import nn +import h5py +import torch +from tqdm import tqdm + +from hloc import logger +from hloc.utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval + +from RDD.matchers import LightGlue + +class Matcher(nn.Module): + default_conf = { + "features": "rdd", + "depth_confidence": 0.95, + "width_confidence": 0.99, + } + + required_inputs = [ + "image0", + "keypoints0", + "descriptors0", + "image1", + "keypoints1", + "descriptors1", + ] + + def __init__(self, conf): + super().__init__() + self.net = LightGlue(conf.pop("features"), **conf) + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + for key in self.required_inputs: + assert key in data, "Missing key {} in data".format(key) + return self._forward(data) + + def _forward(self, data): + data["descriptors0"] = data["descriptors0"].transpose(-1, -2) + data["descriptors1"] = data["descriptors1"].transpose(-1, -2) + + return self.net( + { + "image0": {k[:-1]: v for k, v in data.items() if k[-1] == "0"}, + "image1": {k[:-1]: v for k, v in data.items() if k[-1] == "1"}, + } + ) + + +""" +A set of standard configurations that can be directly selected from the command +line using their name. Each is a dictionary with the following entries: + - output: the name of the match file that will be generated. + - model: the model configuration, as passed to a feature matcher. +""" +confs = { + "rdd+lightglue": { + "output": "matches-rdd-lightglue", + "model": { + "name": "lightglue", + "features": "rdd", + }, + } +} + + +class WorkQueue: + def __init__(self, work_fn, num_threads=1): + self.queue = Queue(num_threads) + self.threads = [ + Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads) + ] + for thread in self.threads: + thread.start() + + def join(self): + for thread in self.threads: + self.queue.put(None) + for thread in self.threads: + thread.join() + + def thread_fn(self, work_fn): + item = self.queue.get() + while item is not None: + work_fn(item) + item = self.queue.get() + + def put(self, data): + self.queue.put(data) + + +class FeaturePairsDataset(torch.utils.data.Dataset): + def __init__(self, pairs, feature_path_q, feature_path_r): + self.pairs = pairs + self.feature_path_q = feature_path_q + self.feature_path_r = feature_path_r + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + data = {} + with h5py.File(self.feature_path_q, "r") as fd: + grp = fd[name0] + for k, v in grp.items(): + data[k + "0"] = torch.from_numpy(v.__array__()).float() + # some matchers might expect an image but only use its size + data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + with h5py.File(self.feature_path_r, "r") as fd: + grp = fd[name1] + for k, v in grp.items(): + data[k + "1"] = torch.from_numpy(v.__array__()).float() + data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + return data + + def __len__(self): + return len(self.pairs) + + +def writer_fn(inp, match_path): + pair, pred = inp + with h5py.File(str(match_path), "a", libver="latest") as fd: + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + matches = pred["matches0"][0].cpu().short().numpy() + grp.create_dataset("matches0", data=matches) + if "matching_scores0" in pred: + scores = pred["matching_scores0"][0].cpu().half().numpy() + grp.create_dataset("matching_scores0", data=scores) + + +def main( + conf: Dict, + pairs: Path, + features: Union[Path, str], + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, + features_ref: Optional[Path] = None, + overwrite: bool = False, + device: str = "cpu", +) -> Path: + if isinstance(features, Path) or Path(features).exists(): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features is not" f" a file path: {features}." + ) + features_q = Path(export_dir, features + ".h5") + if matches is None: + matches = Path(export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5') + + if features_ref is None: + features_ref = features_q + match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite) + + return matches + + +def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): + """Avoid to recompute duplicates to save time.""" + pairs = set() + for i, j in pairs_all: + if (j, i) not in pairs: + pairs.add((i, j)) + pairs = list(pairs) + if match_path is not None and match_path.exists(): + with h5py.File(str(match_path), "r", libver="latest") as fd: + pairs_filtered = [] + for i, j in pairs: + if ( + names_to_pair(i, j) in fd + or names_to_pair(j, i) in fd + or names_to_pair_old(i, j) in fd + or names_to_pair_old(j, i) in fd + ): + continue + pairs_filtered.append((i, j)) + return pairs_filtered + return pairs + + +@torch.no_grad() +def match_from_paths( + conf: Dict, + pairs_path: Path, + match_path: Path, + feature_path_q: Path, + feature_path_ref: Path, + overwrite: bool = False, +) -> Path: + logger.info( + "Matching local features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print(f"Using device: {device}") + if not feature_path_q.exists(): + raise FileNotFoundError(f"Query feature file {feature_path_q}.") + if not feature_path_ref.exists(): + raise FileNotFoundError(f"Reference feature file {feature_path_ref}.") + match_path.parent.mkdir(exist_ok=True, parents=True) + + assert pairs_path.exists(), pairs_path + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + if len(pairs) == 0: + logger.info("Skipping the matching.") + return + + model = Matcher(conf["model"]) + model.eval() + model.to(device) + + dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) + loader = torch.utils.data.DataLoader( + dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True + ) + writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5) + + for idx, data in enumerate(tqdm(loader, smoothing=0.1)): + data = { + k: v if k.startswith("image") else v.to(device, non_blocking=True) + for k, v in data.items() + } + pred = model(data) + + # if matches are less than 25 then skip + pair = names_to_pair(*pairs[idx]) + writer_queue.put((pair, pred)) + writer_queue.join() + logger.info("Finished exporting matches.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--export_dir", type=Path) + parser.add_argument("--features", type=str, default="feats-superpoint-n4096-r1024") + parser.add_argument("--matches", type=Path) + parser.add_argument( + "--conf", type=str, default="superglue", choices=list(confs.keys()) + ) + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir) \ No newline at end of file diff --git a/imcui/third_party/rdd/training/losses/descriptor_loss.py b/imcui/third_party/rdd/training/losses/descriptor_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6177fcecb1187a38376273bc7ff533c73180a8 --- /dev/null +++ b/imcui/third_party/rdd/training/losses/descriptor_loss.py @@ -0,0 +1,78 @@ +import torch +import torch.nn.functional as F + +from training.utils import * +from torch import nn + +class DescriptorLoss(nn.Module): + def __init__(self, inv_temp = 20, dual_softmax_weight = 1, heatmap_weight = 1): + super().__init__() + self.inv_temp = inv_temp + self.dual_softmax_weight = dual_softmax_weight + self.heatmap_weight = heatmap_weight + + def forward(self, m1, m2, h1, h2, pts1, pts2): + loss_ds = dual_softmax_loss(m1, m2, temp=20, normalize=True) * self.dual_softmax_weight + + loss_h1, acc1 = heatmap_loss(h1, pts1) + loss_h2, acc2 = heatmap_loss(h2, pts2) + loss_h = (loss_h1 + loss_h2) / 2 * self.heatmap_weight + + acc_kp = 0.5 * (acc1 + acc2) + + return loss_ds, loss_h, acc_kp + +def dual_softmax_loss(X, Y, temp = 1, normalize = False): + if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2: + raise RuntimeError('Error: X and Y shapes must match and be 2D matrices') + + if normalize: + X = X/X.norm(dim=-1,keepdim=True) + Y = Y/Y.norm(dim=-1,keepdim=True) + + dist_mat = (X @ Y.t()) * temp + + P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1) + + conf_gt = torch.eye(len(X), device = X.device) + pos_mask, neg_mask = conf_gt == 1, conf_gt == 0 + + conf_gt = torch.clamp(conf_gt, 1e-6, 1-1e-6) + + # focal loss + alpha = 0.25 + gamma = 2 + pos_conf = P[pos_mask] + loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log() + + return 5 * loss_pos.mean() + +def heatmap_loss(kpts, pts): + C, H, W = kpts.shape + + with torch.no_grad(): + + labels = torch.zeros((1, H, W), dtype=torch.long, device=kpts.device) + labels[:, (pts[:,1]).long(), (pts[:,0]).long()] = 1 + + kpts = kpts.view(-1) + labels = labels.view(-1) + + # Negative (background) loss to push predictions towards zero + # neg_conf = kpts[neg_mask] + + # Combine positive and negative losses + + BCE_loss = F.binary_cross_entropy(kpts, labels.float(), reduction='none') + pt = torch.exp(-BCE_loss) + F_loss = 0.25 * (1 - pt) ** 2* BCE_loss + + with torch.no_grad(): + predictions = (kpts > 0.5) + true_positives = ((predictions == 1) & (labels == 1)).sum().item() + false_positives = ((predictions == 1) & (labels == 0)).sum().item() + + # Calculate Precision + precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 + + return F_loss.mean(), precision diff --git a/imcui/third_party/rdd/training/losses/detector_loss.py b/imcui/third_party/rdd/training/losses/detector_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a594be7c91dc8cc70d04d548cee1a4e3a33f58a0 --- /dev/null +++ b/imcui/third_party/rdd/training/losses/detector_loss.py @@ -0,0 +1,499 @@ +import torch +import torch.nn.functional as F + +from torch import nn +import cv2 +import numpy as np +from copy import deepcopy +from RDD.dataset.megadepth.utils import warp + +def plot_keypoints(image, kpts, radius=2, color=(255, 0, 0)): + image = image.cpu().detach().numpy() if isinstance(image, torch.Tensor) else image + kpts = kpts.cpu().detach().numpy() if isinstance(kpts, torch.Tensor) else kpts + + if image.dtype is not np.dtype('uint8'): + image = image * 255 + image = image.astype(np.uint8) + + if len(image.shape) == 2 or image.shape[2] == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + out = np.ascontiguousarray(deepcopy(image)) + kpts = np.round(kpts).astype(int) + + for kpt in kpts: + y0, x0 = kpt + cv2.drawMarker(out, (x0, y0), color, cv2.MARKER_CROSS, radius) + + # cv2.circle(out, (x0, y0), radius, color, -1, lineType=cv2.LINE_4) + return out + +class DetectorLoss(nn.Module): + def __init__(self, temperature = 0.1, scores_th = 0.1, peaky_weight = 0.5, reprojection_weight = 1, scoremap_weight = 0.5): + super().__init__() + self.temperature = temperature + self.scores_th = scores_th + self.peaky_weight = peaky_weight + self.reprojection_weight = reprojection_weight + self.scoremap_weight = scoremap_weight + + self.PeakyLoss = PeakyLoss(scores_th = scores_th) + self.ReprojectionLocLoss = ReprojectionLocLoss(scores_th = scores_th) + self.ScoreMapRepLoss = ScoreMapRepLoss(temperature = temperature) + + def forward(self, correspondences, pred0_with_rand, pred1_with_rand): + loss_peaky0 = self.PeakyLoss(pred0_with_rand) + loss_peaky1 = self.PeakyLoss(pred1_with_rand) + loss_peaky = (loss_peaky0 + loss_peaky1) / 2. + + loss_reprojection = self.ReprojectionLocLoss(pred0_with_rand, pred1_with_rand, correspondences) + + loss_score_map_rp = self.ScoreMapRepLoss(pred0_with_rand, pred1_with_rand, correspondences) + + loss_kp = loss_peaky * self.peaky_weight + loss_reprojection * self.reprojection_weight + loss_score_map_rp * self.scoremap_weight + + return loss_kp + +class PeakyLoss(object): + """ PeakyLoss to avoid an uniform score map """ + + def __init__(self, scores_th: float = 0.1): + super().__init__() + self.scores_th = scores_th + + def __call__(self, pred): + b, c, h, w = pred['scores_map'].shape + loss_mean = 0 + CNT = 0 + + for idx in range(b): + n_original = len(pred['score_dispersity'][idx]) + scores_kpts = pred['scores'][idx][:n_original] + valid = scores_kpts > self.scores_th + loss_peaky = pred['score_dispersity'][idx][valid] + + loss_mean = loss_mean + loss_peaky.sum() + CNT = CNT + len(loss_peaky) + + loss_mean = loss_mean / CNT if CNT != 0 else pred['scores_map'].new_tensor(0) + assert not torch.isnan(loss_mean) + return loss_mean + + +class ReprojectionLocLoss(object): + """ + Reprojection location errors of keypoints to train repeatable detector. + """ + + def __init__(self, norm: int = 1, scores_th: float = 0.1): + super().__init__() + self.norm = norm + self.scores_th = scores_th + + def __call__(self, pred0, pred1, correspondences): + b, c, h, w = pred0['scores_map'].shape + loss_mean = 0 + CNT = 0 + for idx in range(b): + if correspondences[idx]['correspondence0'] is None: + continue + + if self.norm == 2: + dist = correspondences[idx]['dist'] + elif self.norm == 1: + dist = correspondences[idx]['dist_l1'] + else: + raise TypeError('No such norm in correspondence.') + + ids0_d = correspondences[idx]['ids0_d'] + ids1_d = correspondences[idx]['ids1_d'] + + scores0 = correspondences[idx]['scores0'].detach()[ids0_d] + scores1 = correspondences[idx]['scores1'].detach()[ids1_d] + valid = (scores0 > self.scores_th) * (scores1 > self.scores_th) + reprojection_errors = dist[ids0_d, ids1_d][valid] + + loss_mean = loss_mean + reprojection_errors.sum() + CNT = CNT + len(reprojection_errors) + + loss_mean = loss_mean / CNT if CNT != 0 else correspondences[0]['dist'].new_tensor(0) + + assert not torch.isnan(loss_mean) + return loss_mean + + +def local_similarity(descriptor_map, descriptors, kpts_wh, radius): + """ + :param descriptor_map: CxHxW + :param descriptors: NxC + :param kpts_wh: Nx2 (W,H) + :return: + """ + _, h, w = descriptor_map.shape + ksize = 2 * radius + 1 + + descriptor_map_unflod = torch.nn.functional.unfold(descriptor_map.unsqueeze(0), + kernel_size=(ksize, ksize), + padding=(radius, radius)) + descriptor_map_unflod = descriptor_map_unflod[0].t().reshape(h * w, -1, ksize * ksize) + # find the correspondence patch + kpts_wh_long = kpts_wh.detach().long() + patch_ids = kpts_wh_long[:, 0] + kpts_wh_long[:, 1] * h + desc_patches = descriptor_map_unflod[patch_ids].permute(0, 2, 1).detach() # N_kpts x s*s x 128 + + local_sim = torch.einsum('nsd,nd->ns', desc_patches, descriptors) + local_sim_sort = torch.sort(local_sim, dim=1, descending=True).values + local_sim_sort_mean = local_sim_sort[:, 4:].mean(dim=1) # 4 is safe radius for bilinear interplation + + return local_sim_sort_mean + + +class ScoreMapRepLoss(object): + """ Scoremap repetability""" + + def __init__(self, temperature: float = 0.1): + super().__init__() + self.temperature = temperature + self.radius = 2 + + def __call__(self, pred0, pred1, correspondences): + b, c, h, w = pred0['scores_map'].shape + wh = pred0['keypoints'][0].new_tensor([[w - 1, h - 1]]) + loss_mean = 0 + CNT = 0 + + for idx in range(b): + if correspondences[idx]['correspondence0'] is None: + continue + + scores_map0 = pred0['scores_map'][idx] + scores_map1 = pred1['scores_map'][idx] + kpts01 = correspondences[idx]['kpts01'] + kpts10 = correspondences[idx]['kpts10'] # valid warped keypoints + + # ===================== + scores_kpts10 = torch.nn.functional.grid_sample(scores_map0.unsqueeze(0), kpts10.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, 0, 0, :] + scores_kpts01 = torch.nn.functional.grid_sample(scores_map1.unsqueeze(0), kpts01.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, 0, 0, :] + + s0 = scores_kpts01 * correspondences[idx]['scores0'] # repeatability + s1 = scores_kpts10 * correspondences[idx]['scores1'] # repeatability + + # ===================== repetability + similarity_map_01 = correspondences[idx]['similarity_map_01_valid'] + similarity_map_10 = correspondences[idx]['similarity_map_10_valid'] + + pmf01 = ((similarity_map_01.detach() - 1) / self.temperature).exp() + pmf10 = ((similarity_map_10.detach() - 1) / self.temperature).exp() + kpts01 = kpts01.detach() + kpts10 = kpts10.detach() + + pmf01_kpts = torch.nn.functional.grid_sample(pmf01.unsqueeze(0), kpts01.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, :, 0, :] + pmf10_kpts = torch.nn.functional.grid_sample(pmf10.unsqueeze(0), kpts10.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, :, 0, :] + repetability01 = torch.diag(pmf01_kpts) + repetability10 = torch.diag(pmf10_kpts) + + # ===================== reliability + # ids0, ids1 = correspondences[idx]['ids0'], correspondences[idx]['ids1'] + # descriptor_map0 = pred0['descriptor_map'][idx].detach() + # descriptor_map1 = pred1['descriptor_map'][idx].detach() + # descriptors0 = pred0['descriptors'][idx][ids0].detach() + # descriptors1 = pred1['descriptors'][idx][ids1].detach() + # kpts0 = pred0['keypoints'][idx][ids0].detach() + # kpts1 = pred1['keypoints'][idx][ids1].detach() + # kpts0_wh = (kpts0 / 2 + 0.5) * wh + # kpts1_wh = (kpts1 / 2 + 0.5) * wh + # ls0 = local_similarity(descriptor_map0, descriptors0, kpts0_wh, self.radius) + # ls1 = local_similarity(descriptor_map1, descriptors1, kpts1_wh, self.radius) + # reliability0 = 1 - ((ls0 - 1) / self.temperature).exp() + # reliability1 = 1 - ((ls1 - 1) / self.temperature).exp() + + fs0 = repetability01 # * reliability0 + fs1 = repetability10 # * reliability1 + + if s0.sum() != 0: + loss01 = (1 - fs0) * s0 * len(s0) / s0.sum() + loss_mean = loss_mean + loss01.sum() + CNT = CNT + len(loss01) + if s1.sum() != 0: + loss10 = (1 - fs1) * s1 * len(s1) / s1.sum() + loss_mean = loss_mean + loss10.sum() + CNT = CNT + len(loss10) + + loss_mean = loss_mean / CNT if CNT != 0 else pred0['scores_map'].new_tensor(0) + assert not torch.isnan(loss_mean) + return loss_mean + + + +#+++++++++++++++++++++++++++++++++++++++++++++++++++Taken from ALIKE+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +def compute_keypoints_distance(kpts0, kpts1, p=2, debug=False): + """ + Args: + kpts0: torch.tensor [M,2] + kpts1: torch.tensor [N,2] + p: (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm + + Returns: + dist, torch.tensor [N,M] + """ + dist = kpts0[:, None, :] - kpts1[None, :, :] # [M,N,2] + dist = torch.norm(dist, p=p, dim=2) # [M,N] + return dist + +def mutual_argmax(value, mask=None, as_tuple=True): + """ + Args: + value: MxN + mask: MxN + + Returns: + + """ + value = value - value.min() # convert to non-negative tensor + if mask is not None: + value = value * mask + + max0 = value.max(dim=1, keepdim=True) # the col index the max value in each row + max1 = value.max(dim=0, keepdim=True) + + valid_max0 = value == max0[0] + valid_max1 = value == max1[0] + + mutual = valid_max0 * valid_max1 + if mask is not None: + mutual = mutual * mask + + return mutual.nonzero(as_tuple=as_tuple) + + +def mutual_argmin(value, mask=None): + return mutual_argmax(-value, mask) + +def compute_correspondence(model, pred0, pred1, batch, radius = 2, rand=True, train_gt_th = 5, debug=False): + b, c, h, w = pred0['scores_map'].shape + wh = pred0['scores_map'][0].new_tensor([[w - 1, h - 1]]) + + pred0_with_rand = pred0 + pred1_with_rand = pred1 + pred0_with_rand['scores'] = [] + pred1_with_rand['scores'] = [] + pred0_with_rand['descriptors'] = [] + pred1_with_rand['descriptors'] = [] + pred0_with_rand['num_det'] = [] + pred1_with_rand['num_det'] = [] + + kps, score_dispersity, scores = model.softdetect.detect_keypoints(pred0['scores_map']) + + pred0_with_rand['keypoints'] = kps + pred0_with_rand['score_dispersity'] = score_dispersity + + + kps, score_dispersity, scores = model.softdetect.detect_keypoints(pred1['scores_map']) + pred1_with_rand['keypoints'] = kps + pred1_with_rand['score_dispersity'] = score_dispersity + + correspondences = [] + for idx in range(b): + # =========================== prepare keypoints + kpts0, kpts1 = pred0['keypoints'][idx], pred1['keypoints'][idx] # (x,y), shape: Nx2 + + # additional random keypoints + if rand: + rand0 = torch.rand(len(kpts0), 2, device=kpts0.device) * 2 - 1 # -1~1 + rand1 = torch.rand(len(kpts1), 2, device=kpts1.device) * 2 - 1 # -1~1 + kpts0 = torch.cat([kpts0, rand0]) + kpts1 = torch.cat([kpts1, rand1]) + + pred0_with_rand['keypoints'][idx] = kpts0 + pred1_with_rand['keypoints'][idx] = kpts1 + + scores_map0 = pred0['scores_map'][idx] + scores_map1 = pred1['scores_map'][idx] + scores_kpts0 = torch.nn.functional.grid_sample(scores_map0.unsqueeze(0), kpts0.view(1, 1, -1, 2), + mode='bilinear', align_corners=True).squeeze() + scores_kpts1 = torch.nn.functional.grid_sample(scores_map1.unsqueeze(0), kpts1.view(1, 1, -1, 2), + mode='bilinear', align_corners=True).squeeze() + + kpts0_wh_ = (kpts0 / 2 + 0.5) * wh # N0x2, (w,h) + kpts1_wh_ = (kpts1 / 2 + 0.5) * wh # N1x2, (w,h) + + # ========================= nms + dist = compute_keypoints_distance(kpts0_wh_.detach(), kpts0_wh_.detach()) + local_mask = dist < radius + valid_cnt = torch.sum(local_mask, dim=1) + indices_need_nms = torch.where(valid_cnt > 1)[0] + + for i in indices_need_nms: + if valid_cnt[i] > 0: + kpt_indices = torch.where(local_mask[i])[0] + scs_max_idx = scores_kpts0[kpt_indices].argmax() + + tmp_mask = kpt_indices.new_ones(len(kpt_indices)).bool() + tmp_mask[scs_max_idx] = False + suppressed_indices = kpt_indices[tmp_mask] + + valid_cnt[suppressed_indices] = 0 + + valid_mask = valid_cnt > 0 + kpts0_wh = kpts0_wh_[valid_mask] + kpts0 = kpts0[valid_mask] + scores_kpts0 = scores_kpts0[valid_mask] + pred0_with_rand['keypoints'][idx] = kpts0 + + valid_mask = valid_mask[:len(pred0_with_rand['score_dispersity'][idx])] + pred0_with_rand['score_dispersity'][idx] = pred0_with_rand['score_dispersity'][idx][valid_mask] + pred0_with_rand['num_det'].append(valid_mask.sum()) + + dist = compute_keypoints_distance(kpts1_wh_.detach(), kpts1_wh_.detach()) + local_mask = dist < radius + valid_cnt = torch.sum(local_mask, dim=1) + indices_need_nms = torch.where(valid_cnt > 1)[0] + for i in indices_need_nms: + if valid_cnt[i] > 0: + kpt_indices = torch.where(local_mask[i])[0] + scs_max_idx = scores_kpts1[kpt_indices].argmax() + + tmp_mask = kpt_indices.new_ones(len(kpt_indices)).bool() + tmp_mask[scs_max_idx] = False + suppressed_indices = kpt_indices[tmp_mask] + + valid_cnt[suppressed_indices] = 0 + + valid_mask = valid_cnt > 0 + kpts1_wh = kpts1_wh_[valid_mask] + kpts1 = kpts1[valid_mask] + scores_kpts1 = scores_kpts1[valid_mask] + pred1_with_rand['keypoints'][idx] = kpts1 + + valid_mask = valid_mask[:len(pred1_with_rand['score_dispersity'][idx])] + pred1_with_rand['score_dispersity'][idx] = pred1_with_rand['score_dispersity'][idx][valid_mask] + pred1_with_rand['num_det'].append(valid_mask.sum()) + + # del dist, local_mask, valid_cnt, indices_need_nms, scs_max_idx, tmp_mask, suppressed_indices, valid_mask + # torch.cuda.empty_cache() + # ========================= nms + + pred0_with_rand['scores'].append(scores_kpts0) + pred1_with_rand['scores'].append(scores_kpts1) + + descriptor_map0, descriptor_map1 = pred0['descriptor_map'][idx], pred1['descriptor_map'][idx] + descriptor_map0 = F.normalize(descriptor_map0, dim=0) + descriptor_map1 = F.normalize(descriptor_map1, dim=0) + + desc0 = torch.nn.functional.grid_sample(descriptor_map0.unsqueeze(0), kpts0.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, :, 0, :].t() + desc1 = torch.nn.functional.grid_sample(descriptor_map1.unsqueeze(0), kpts1.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, :, 0, :].t() + desc0 = F.normalize(desc0, dim=-1) + desc1 = F.normalize(desc1, dim=-1) + + pred0_with_rand['descriptors'].append(desc0) + pred1_with_rand['descriptors'].append(desc1) + + # =========================== prepare warp parameters + warp01_params = {} + for k, v in batch['warp01_params'].items(): + warp01_params[k] = v[idx] + warp10_params = {} + for k, v in batch['warp10_params'].items(): + warp10_params[k] = v[idx] + + # =========================== warp keypoints across images + try: + kpts0_wh, kpts01_wh, ids0, ids0_out = warp(kpts0_wh, warp01_params) + kpts1_wh, kpts10_wh, ids1, ids1_out = warp(kpts1_wh, warp10_params) + except: + correspondences.append({'correspondence0': None, 'correspondence1': None, + 'dist': kpts0_wh.new_tensor(0), + }) + continue + + if debug: + from training.utils import save_image_in_actual_size + + image0 = batch['image0'][idx].cpu().detach().numpy().transpose(1, 2, 0) + image1 = batch['image1'][idx].cpu().detach().numpy().transpose(1, 2, 0) + + p0 = kpts0_wh[:, [1, 0]].cpu().detach().numpy() + img_kpts0 = plot_keypoints(image0, p0, radius=5, color=(255, 0, 0)) + # display_image_in_actual_size(img_kpts0) + + p1 = kpts1_wh[:, [1, 0]].cpu().detach().numpy() + img_kpts1 = plot_keypoints(image1, p1, radius=5, color=(255, 0, 0)) + # display_image_in_actual_size(img_kpts1) + + p01 = kpts01_wh[:, [1, 0]].cpu().detach().numpy() + img_kpts01 = plot_keypoints(img_kpts1, p01, radius=5, color=(0, 255, 0)) + save_image_in_actual_size(img_kpts01, name='kpts01.png') + + p10 = kpts10_wh[:, [1, 0]].cpu().detach().numpy() + img_kpts10 = plot_keypoints(img_kpts0, p10, radius=5, color=(0, 255, 0)) + save_image_in_actual_size(img_kpts10, name='kpts10.png') + + # ============================= compute reprojection error + dist01 = compute_keypoints_distance(kpts0_wh, kpts10_wh) + dist10 = compute_keypoints_distance(kpts1_wh, kpts01_wh) + + dist_l2 = (dist01 + dist10.t()) / 2. + # find mutual correspondences by calculating the distance + # between keypoints (I1) and warpped keypoints (I2->I1) + mutual_min_indices = mutual_argmin(dist_l2) + + dist_mutual_min = dist_l2[mutual_min_indices] + valid_dist_mutual_min = dist_mutual_min.detach() < train_gt_th + + ids0_d = mutual_min_indices[0][valid_dist_mutual_min] + ids1_d = mutual_min_indices[1][valid_dist_mutual_min] + + correspondence0 = ids0[ids0_d] + correspondence1 = ids1[ids1_d] + + # L1 distance + dist01_l1 = compute_keypoints_distance(kpts0_wh, kpts10_wh, p=1) + dist10_l1 = compute_keypoints_distance(kpts1_wh, kpts01_wh, p=1) + + dist_l1 = (dist01_l1 + dist10_l1.t()) / 2. + + # =========================== compute cross image descriptor similarity_map + similarity_map_01 = (desc0 @ descriptor_map1.reshape(h*w, 256).t()) * 20 + similarity_map_01 = similarity_map_01.softmax(dim = -2) * similarity_map_01.softmax(dim= -1) + similarity_map_01 = similarity_map_01.reshape(desc0.shape[0], h, w) + similarity_map_01 = torch.clamp(similarity_map_01, 1e-6, 1-1e-6) + + similarity_map_10 = (desc1 @ descriptor_map0.reshape(h*w, 256).t()) * 20 + similarity_map_10 = similarity_map_10.softmax(dim = -2) * similarity_map_10.softmax(dim= -1) + similarity_map_10 = similarity_map_10.reshape(desc1.shape[0], h, w) + similarity_map_10 = torch.clamp(similarity_map_10, 1e-6, 1-1e-6) + + similarity_map_01_valid = similarity_map_01[ids0] # valid descriptors + similarity_map_10_valid = similarity_map_10[ids1] + + kpts01 = 2 * kpts01_wh.detach() / wh - 1 # N0x2, (x,y), [-1,1] + kpts10 = 2 * kpts10_wh.detach() / wh - 1 # N0x2, (x,y), [-1,1] + + correspondences.append({'correspondence0': correspondence0, # indices of matched kpts0 in all kpts + 'correspondence1': correspondence1, # indices of matched kpts1 in all kpts + 'scores0': scores_kpts0[ids0], + 'scores1': scores_kpts1[ids1], + 'kpts01': kpts01, 'kpts10': kpts10, # warped valid kpts + 'ids0': ids0, 'ids1': ids1, # valid indices of kpts0 and kpts1 + 'ids0_out': ids0_out, 'ids1_out': ids1_out, + 'ids0_d': ids0_d, 'ids1_d': ids1_d, # match indices of valid kpts0 and kpts1 + 'dist_l1': dist_l1, # cross distance matrix of valid kpts using L1 norm + 'dist': dist_l2, # cross distance matrix of valid kpts using L2 norm + 'similarity_map_01': similarity_map_01, # all + 'similarity_map_10': similarity_map_10, # all + 'similarity_map_01_valid': similarity_map_01_valid, # valid + 'similarity_map_10_valid': similarity_map_10_valid, # valid + }) + + return correspondences, pred0_with_rand, pred1_with_rand + + +class EmptyTensorError(Exception): + pass \ No newline at end of file diff --git a/imcui/third_party/rdd/training/train.py b/imcui/third_party/rdd/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f517318733a516824327506bcb924ab25b52965f --- /dev/null +++ b/imcui/third_party/rdd/training/train.py @@ -0,0 +1,448 @@ +import argparse +import os +import time +import sys +import glob +from pathlib import Path +from RDD.RDD_helper import RDD_helper +import torch.distributed + +def parse_arguments(): + parser = argparse.ArgumentParser(description="XFeat training script.") + + parser.add_argument('--megadepth_root_path', type=str, default='./data/megadepth', + help='Path to the MegaDepth dataset root directory.') + parser.add_argument('--test_data_root', type=str, default='./data/megadepth_test_1500', + help='Path to the MegaDepth test dataset root directory.') + parser.add_argument('--ckpt_save_path', type=str, required=True, + help='Path to save the checkpoints.') + parser.add_argument('--model_name', type=str, default='RDD', + help='Name of the model to save.') + parser.add_argument('--air_ground_root_path', type=str, default='./data/air_ground_data_2/AirGround') + parser.add_argument('--batch_size', type=int, default=4, + help='Batch size for training. Default is 4.') + parser.add_argument('--lr', type=float, default=1e-4, + help='Learning rate. Default is 0.0001.') + parser.add_argument('--gamma_steplr', type=float, default=0.5, + help='Gamma value for StepLR scheduler. Default is 0.5.') + parser.add_argument('--training_res', type=int, + default=800, help='Training resolution as width,height. Default is 800 for training descriptor.') + parser.add_argument('--save_ckpt_every', type=int, default=500, + help='Save checkpoints every N steps. Default is 500.') + parser.add_argument('--test_every_iter', type=int, default=2000, + help='Save checkpoints every N steps. Default is 2000.') + parser.add_argument('--weights', type=str, default=None,) + parser.add_argument('--num_encoder_layers', type=int, default=4) + parser.add_argument('--enc_n_points', type=int, default=8) + parser.add_argument('--num_feature_levels', type=int, default=5) + parser.add_argument('--train_detector', action='store_true', default=False) + parser.add_argument('--epochs', type=int, default=20) + parser.add_argument('--distributed', action='store_true', default=False) + parser.add_argument('--config_path', type=str, default='./configs/default.yaml') + args = parser.parse_args() + + return args + +args = parse_arguments() + +import torch +from torch import optim +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +import numpy as np +from RDD.RDD import build +from training.utils import * +from training.losses import * +from benchmarks.mega_1500 import MegaDepthPoseMNNBenchmark +from RDD.dataset.megadepth.megadepth import MegaDepthDataset +from RDD.dataset.megadepth import megadepth_warper +from torch.utils.data import Dataset, DataLoader, DistributedSampler, RandomSampler, WeightedRandomSampler +from training.losses.detector_loss import compute_correspondence, DetectorLoss +from training.losses.descriptor_loss import DescriptorLoss +import tqdm +from torch.optim.lr_scheduler import MultiStepLR, StepLR +from datetime import timedelta +from RDD.utils import read_config +torch.autograd.set_detect_anomaly(True) +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.deterministic = True + +class Trainer(): + """ + Class for training XFeat with default params as described in the paper. + We use a blend of MegaDepth (labeled) pairs with synthetically warped images (self-supervised). + The major bottleneck is to keep loading huge megadepth h5 files from disk, + the network training itself is quite fast. + """ + + def __init__(self, rank, args=None): + config = read_config(args.config_path) + + config['num_encoder_layers'] = args.num_encoder_layers + config['enc_n_points'] = args.enc_n_points + config['num_feature_levels'] = args.num_feature_levels + config['train_detector'] = args.train_detector + config['weights'] = args.weights + + # distributed training + if args.distributed: + print(f"Training in distributed mode with {args.n_gpus} GPUs") + assert torch.cuda.is_available() + device = rank + + torch.distributed.init_process_group( + backend="nccl", + world_size=args.n_gpus, + rank=device, + init_method="file://" + str(args.lock_file), + timeout=timedelta(seconds=2000) + ) + torch.cuda.set_device(device) + + # adjust batch size and num of workers since these are per GPU + batch_size = int(args.batch_size / args.n_gpus) + self.n_gpus = args.n_gpus + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + batch_size = args.batch_size + print(f"Using device {device}") + + self.seed = 0 + self.set_seed(self.seed) + self.training_res = args.training_res + self.dev = device + config['device'] = device + model = build(config) + + self.rank = rank + + if args.weights is not None: + print('Loading weights from ', args.weights) + model.load_state_dict(torch.load(args.weights, map_location='cpu')) + + if args.distributed: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + self.model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True + ) + else: + self.model = model.to(device) + + self.saved_ckpts = [] + self.best = -1.0 + self.best_loss = 1e6 + self.fine_weight = 1.0 + self.dual_softmax_weight = 1.0 + self.heatmaps_weight = 1.0 + #Setup optimizer + self.batch_size = batch_size + self.epochs = args.epochs + self.opt = optim.AdamW(filter(lambda x: x.requires_grad, self.model.parameters()) , lr = args.lr, weight_decay=1e-4) + + # losses + if args.train_detector: + self.DetectorLoss = DetectorLoss(temperature=0.1, scores_th=0.1) + else: + self.DescriptorLoss = DescriptorLoss(inv_temp=20, dual_softmax_weight=1, heatmap_weight=1) + + self.benchmark = MegaDepthPoseMNNBenchmark(data_root=args.test_data_root) + + ##################### MEGADEPTH INIT ########################## + + TRAIN_BASE_PATH = f"{args.megadepth_root_path}/megadepth_indices" + print('Loading MegaDepth dataset from ', TRAIN_BASE_PATH) + TRAINVAL_DATA_SOURCE = args.megadepth_root_path + self.TRAINVAL_DATA_SOURCE = TRAINVAL_DATA_SOURCE + TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" + self.TRAIN_NPZ_ROOT = TRAIN_NPZ_ROOT + npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:] + self.npz_paths = npz_paths + self.epoch = 0 + self.create_data_loader() + + ##################### MEGADEPTH INIT END ####################### + + os.makedirs(args.ckpt_save_path, exist_ok=True) + os.makedirs(args.ckpt_save_path / 'logdir', exist_ok=True) + + self.save_ckpt_every = args.save_ckpt_every + self.ckpt_save_path = args.ckpt_save_path + if rank == 0: + self.writer = SummaryWriter(str(self.ckpt_save_path) + f'/logdir/{args.model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S")) + else: + self.writer = None + self.model_name = args.model_name + + if args.distributed: + self.scheduler = MultiStepLR(self.opt, milestones=[2, 4, 8, 16], gamma=args.gamma_steplr) + else: + self.scheduler = StepLR(self.opt, step_size=args.test_every_iter, gamma=args.gamma_steplr) + + def set_seed(self, seed): + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + def create_data_loader(self): + # Create sampler + + + if not args.train_detector: + mega_crop = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, + npz_path = path, min_overlap_score=0.01, max_overlap_score=0.7, image_size=self.training_res, num_per_scene=200, gray=False, crop_or_scale='crop') for path in self.npz_paths] ) + mega_scale = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, + npz_path = path, min_overlap_score=0.01, max_overlap_score=0.7, image_size=self.training_res, num_per_scene=200, gray=False, crop_or_scale='scale') for path in self.npz_paths] ) + combined_dataset = torch.utils.data.ConcatDataset([mega_crop, mega_scale]) + + else: + + mega_crop = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, + npz_path = path, min_overlap_score=0.1, max_overlap_score=0.8, image_size=self.training_res, num_per_scene=100, gray=False, crop_or_scale='crop') for path in self.npz_paths] ) + mega_scale = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, + npz_path = path, min_overlap_score=0.1, max_overlap_score=0.8, image_size=self.training_res, num_per_scene=100, gray=False, crop_or_scale='scale') for path in self.npz_paths] ) + combined_dataset = torch.utils.data.ConcatDataset([mega_crop, mega_scale]) + + # Create sampler + if args.distributed: + sampler = DistributedSampler(combined_dataset, rank=self.rank, num_replicas=self.n_gpus) + else: + # Create sampler + sampler = RandomSampler(combined_dataset) + + # Create single DataLoader with combined dataset + self.data_loader = DataLoader(combined_dataset, + batch_size=self.batch_size, + sampler=sampler, + num_workers=4, + pin_memory=True) + + def validate(self, total_steps): + + with torch.no_grad(): + + + if args.train_detector: + method = 'sparse' + else: + method = 'aliked' + + if args.distributed: + self.model.module.eval() + model_helper = RDD_helper(self.model.module) + test_out = self.benchmark.benchmark(model_helper, model_name='experiment', plot_every_iter=1, plot=False, method=method) + else: + self.model.eval() + model_helper = RDD_helper(self.model) + test_out = self.benchmark.benchmark(model_helper, model_name='experiment', plot_every_iter=1, plot=False, method=method) + + auc5 = test_out['auc_5'] + auc10 = test_out['auc_10'] + auc20 = test_out['auc_20'] + if self.rank == 0: + self.writer.add_scalar('Accuracy/auc5', auc5, total_steps) + self.writer.add_scalar('Accuracy/auc10', auc10, total_steps) + self.writer.add_scalar('Accuracy/auc20', auc20, total_steps) + if auc5 > self.best: + self.best = auc5 + if args.distributed: + torch.save(self.model.module.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_best.pth') + else: + torch.save(self.model.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_best.pth') + + self.model.train() + + + def _inference(self, d): + if d is not None: + for k in d.keys(): + if isinstance(d[k], torch.Tensor): + d[k] = d[k].to(self.dev) + p1, p2 = d['image0'], d['image1'] + + if not args.train_detector: + positives_md_coarse = megadepth_warper.spvs_coarse(d, self.stride) + + with torch.no_grad(): + p1 = p1 ; p2 = p2 + if not args.train_detector: + positives_c = positives_md_coarse + + + # Check if batch is corrupted with too few correspondences + is_corrupted = False + if not args.train_detector: + for p in positives_c: + if len(p) < 30: + is_corrupted = True + + if is_corrupted: + return None, None, None, None + + # Forward pass + + feats1, scores_map1, hmap1 = self.model(p1) + feats2, scores_map2, hmap2 = self.model(p2) + + if args.train_detector: + + # move all tensors on batch to GPU + for k in d.keys(): + if isinstance(d[k], torch.Tensor): + d[k] = d[k].to(self.dev) + elif isinstance(d[k], dict): + for k2 in d[k].keys(): + if isinstance(d[k][k2], torch.Tensor): + d[k][k2] = d[k][k2].to(self.dev) + + # Get positive correspondencies + pred0 = {'descriptor_map': F.interpolate(feats1, size=scores_map1.shape[-2:], mode='bilinear', align_corners=True), 'scores_map': scores_map1 } + pred1 = {'descriptor_map': F.interpolate(feats2, size=scores_map2.shape[-2:], mode='bilinear', align_corners=True), 'scores_map': scores_map2 } + if args.distributed: + correspondences, pred0_with_rand, pred1_with_rand = compute_correspondence(self.model.module, pred0, pred1, d, debug=True) + else: + correspondences, pred0_with_rand, pred1_with_rand = compute_correspondence(self.model, pred0, pred1, d, debug=False) + + loss_kp = self.DetectorLoss(correspondences, pred0_with_rand, pred1_with_rand) + + loss = loss_kp + acc_coarse, acc_kp, nb_coarse = 0, 0, 0 + else: + + loss_items = [] + acc_coarse_items = [] + acc_kp_items = [] + + for b in range(len(positives_c)): + + if len(positives_c[b]) > 10000: + positives = positives_c[b][torch.randperm(len(positives_c[b]))[:10000]] + else: + positives = positives_c[b] + # Get positive correspondencies + pts1, pts2 = positives[:, :2], positives[:, 2:] + + h1 = hmap1[b, :, :, :] + h2 = hmap2[b, :, :, :] + + m1 = feats1[b, :, pts1[:,1].long(), pts1[:,0].long()].permute(1,0) + m2 = feats2[b, :, pts2[:,1].long(), pts2[:,0].long()].permute(1,0) + # Compute losses + loss_ds, loss_h, acc_kp = self.DescriptorLoss(m1, m2, h1, h2, pts1, pts2) + + loss_items.append(loss_ds.unsqueeze(0)) + + acc_coarse = check_accuracy1(m1, m2) + acc_kp_items.append(acc_kp) + acc_coarse_items.append(acc_coarse) + + nb_coarse = len(m1) + loss = loss_kp if args.train_detector else torch.cat(loss_items, -1).mean() + acc_coarse = sum(acc_coarse_items) / len(acc_coarse_items) + acc_kp = sum(acc_kp_items) / len(acc_kp_items) + + return loss, acc_coarse, acc_kp, nb_coarse + + def train(self): + + self.model.train() + self.stride = 4 if args.num_feature_levels == 5 else 8 + total_steps = 0 + + for epoch in range(self.epochs): + + if args.distributed: + self.data_loader.sampler.set_epoch(epoch) + pbar = tqdm.tqdm(total=len(self.data_loader), desc=f"Epoch {epoch+1}/{args.epochs}") if self.rank == 0 else None + + for i, d in enumerate(self.data_loader): + + loss, acc_coarse, acc_kp, nb_coarse = self._inference(d) + + if loss is None: + continue + + # Compute Backward Pass + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) + self.opt.step() + self.opt.zero_grad() + + if (total_steps + 1) % self.save_ckpt_every == 0 and self.rank == 0: + print('saving iter ', total_steps + 1) + if args.distributed: + torch.save(self.model.module.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_{total_steps + 1}.pth') + else: + torch.save(self.model.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_{total_steps + 1}.pth') + self.saved_ckpts.append(total_steps + 1) + if len(self.saved_ckpts) > 5: + os.remove(str(self.ckpt_save_path) + f'/{self.model_name}_{self.saved_ckpts[0]}.pth') + self.saved_ckpts = self.saved_ckpts[1:] + + if args.distributed: + torch.distributed.barrier() + + if (total_steps+1) % args.test_every_iter == 0: + self.validate(total_steps) + + if pbar is not None: + + if args.train_detector: + pbar.set_description( 'Loss: {:.4f} '.format(loss.item()) ) + else: + pbar.set_description( 'Loss: {:.4f} acc_coarse {:.3f} acc_kp: {:.3f} #matches_c: {:d}'.format( + loss.item(), acc_coarse, acc_kp, nb_coarse) ) + pbar.update(1) + + # Log metrics + if self.rank == 0: + self.writer.add_scalar('Loss/total', loss.item(), total_steps) + self.writer.add_scalar('Accuracy/coarse_mdepth', acc_coarse, total_steps) + self.writer.add_scalar('Count/matches_coarse', nb_coarse, total_steps) + + if not args.distributed: + self.scheduler.step() + total_steps = total_steps + 1 + + self.validate(total_steps) + if self.rank == 0: + print('Epoch ', epoch, ' done.') + print('Creating new data loader with seed ', self.seed) + self.seed = self.seed + 1 + self.set_seed(self.seed) + self.scheduler.step() + self.epoch = self.epoch + 1 + self.create_data_loader() + +def main_worker(rank, args): + trainer = Trainer( + rank=rank, + args=args + ) + + # The most fun part + trainer.train() + +if __name__ == '__main__': + if args.distributed: + import torch.multiprocessing as mp + mp.set_start_method('spawn', force=True) + + if not Path(args.ckpt_save_path).exists(): + os.makedirs(args.ckpt_save_path) + + args.ckpt_save_path = Path(args.ckpt_save_path).resolve() + + if args.distributed: + args.n_gpus = torch.cuda.device_count() + args.lock_file = Path(args.ckpt_save_path) / "distributed_lock" + if args.lock_file.exists(): + args.lock_file.unlink() + + # Each process gets its own rank and dataset + torch.multiprocessing.spawn( + main_worker, nprocs=args.n_gpus, args=(args,) + ) + else: + main_worker(0, args) \ No newline at end of file diff --git a/imcui/third_party/rdd/training/utils.py b/imcui/third_party/rdd/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8e5e8b65ccd13b176d7d24491f6c41c2288193 --- /dev/null +++ b/imcui/third_party/rdd/training/utils.py @@ -0,0 +1,258 @@ +import torch +import numpy as np +import pdb +from copy import deepcopy +import cv2 +debug_cnt = -1 +from RDD.matchers import DualSoftmaxMatcher + +matcher = DualSoftmaxMatcher(inv_temperature = 1) + +def make_batch(augmentor, difficulty = 0.3, train = True): + Hs = [] + img_list = augmentor.train if train else augmentor.test + dev = augmentor.device + batch_images = [] + + with torch.no_grad(): # we dont require grads in the augmentation + for b in range(augmentor.batch_size): + rdidx = np.random.randint(len(img_list)) + img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2,0,1).to(augmentor.device).unsqueeze(0) + batch_images.append(img) + + batch_images = torch.cat(batch_images) + + p1, H1 = augmentor(batch_images, difficulty) + p2, H2 = augmentor(batch_images, difficulty, TPS = True, prob_deformation = 0.7) + + return p1, p2, H1, H2 + +def plot_corrs(p1, p2, src_pts, tgt_pts): + import matplotlib.pyplot as plt + p1 = p1.cpu() + p2 = p2.cpu() + src_pts = src_pts.cpu() ; tgt_pts = tgt_pts.cpu() + rnd_idx = np.random.randint(len(src_pts), size=200) + src_pts = src_pts[rnd_idx, ...] + tgt_pts = tgt_pts[rnd_idx, ...] + + #Plot ground-truth correspondences + fig, ax = plt.subplots(1,2,figsize=(18, 12)) + colors = np.random.uniform(size=(len(tgt_pts),3)) + #Src image + img = p1 + for i, p in enumerate(src_pts): + ax[0].scatter(p[0],p[1],color=colors[i]) + ax[0].imshow(img.permute(1,2,0).numpy()[...,::-1]) + + #Target img + img2 = p2 + for i, p in enumerate(tgt_pts): + ax[1].scatter(p[0],p[1],color=colors[i]) + ax[1].imshow(img2.permute(1,2,0).numpy()[...,::-1]) + plt.show() + + +def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop = None): + ''' + Get dense corresponding points + ''' + global debug_cnt + negatives, positives = [], [] + + with torch.no_grad(): + #real input res of samples + rh, rw = p1.shape[-2:] + ratio = torch.tensor([rw/w, rh/h], device = p1.device) + + (H, mask1) = H + (H2, src, W, A, mask2) = H2 + + #Generate meshgrid of target pts + x, y = torch.meshgrid(torch.arange(w, device=p1.device), torch.arange(h, device=p1.device), indexing ='xy') + mesh = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1) + target_pts = mesh.view(-1, 2) * ratio + + #Pack all transformations into T + for batch_idx in range(len(p1)): + with torch.no_grad(): + T = (H[batch_idx], H2[batch_idx], + src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0)) + #We now warp the target points to src image + src_pts = (augmentor.get_correspondences(target_pts, T) ) #target to src + tgt_pts = (target_pts) + + #Check out of bounds points + mask_valid = (src_pts[:, 0] >=0) & (src_pts[:, 1] >=0) & \ + (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh) + + negatives.append( tgt_pts[~mask_valid] ) + tgt_pts = tgt_pts[mask_valid] + src_pts = src_pts[mask_valid] + + + #Remove invalid pixels + mask_valid = mask1[batch_idx, src_pts[:,1].long(), src_pts[:,0].long()] & \ + mask2[batch_idx, tgt_pts[:,1].long(), tgt_pts[:,0].long()] + tgt_pts = tgt_pts[mask_valid] + src_pts = src_pts[mask_valid] + + # limit nb of matches if desired + if crop is not None: + rnd_idx = torch.randperm(len(src_pts), device=src_pts.device)[:crop] + src_pts = src_pts[rnd_idx] + tgt_pts = tgt_pts[rnd_idx] + + if debug_cnt >=0 and debug_cnt < 4: + plot_corrs(p1[batch_idx], p2[batch_idx], src_pts , tgt_pts ) + debug_cnt +=1 + + src_pts = (src_pts / ratio) + tgt_pts = (tgt_pts / ratio) + + #Check out of bounds points + padto = 10 if crop is not None else 2 + mask_valid1 = (src_pts[:, 0] >= (0 + padto)) & (src_pts[:, 1] >= (0 + padto)) & \ + (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto)) + mask_valid2 = (tgt_pts[:, 0] >= (0 + padto)) & (tgt_pts[:, 1] >= (0 + padto)) & \ + (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto)) + mask_valid = mask_valid1 & mask_valid2 + tgt_pts = tgt_pts[mask_valid] + src_pts = src_pts[mask_valid] + + #Remove repeated correspondences + lut_mat = torch.ones((h, w, 4), device = src_pts.device, dtype = src_pts.dtype) * -1 + # src_pts_np = src_pts.cpu().numpy() + # tgt_pts_np = tgt_pts.cpu().numpy() + try: + lut_mat[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) + mask_valid = torch.all(lut_mat >= 0, dim=-1) + points = lut_mat[mask_valid] + positives.append(points) + except: + pdb.set_trace() + print('..') + + return negatives, positives + + +def crop_patches(tensor, coords, size = 7): + ''' + Crop [size x size] patches around 2D coordinates from a tensor. + ''' + B, C, H, W = tensor.shape + + x, y = coords[:, 0], coords[:, 1] + y = y.view(-1, 1, 1) + x = x.view(-1, 1, 1) + halfsize = size // 2 + # Create meshgrid for indexing + x_offset, y_offset = torch.meshgrid(torch.arange(-halfsize, halfsize+1), torch.arange(-halfsize, halfsize+1), indexing='xy') + y_offset = y_offset.to(tensor.device) + x_offset = x_offset.to(tensor.device) + + # Compute indices around each coordinate + y_indices = (y + y_offset.view(1, size, size)).squeeze(0) + halfsize + x_indices = (x + x_offset.view(1, size, size)).squeeze(0) + halfsize + + # Handle out-of-boundary indices with padding + tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant') + + # Index tensor to get patches + patches = tensor_padded[:, :, y_indices, x_indices] # [B, C, N, H, W] + return patches + +def subpix_softmax2d(heatmaps, temp = 0.25): + N, H, W = heatmaps.shape + heatmaps = torch.softmax(temp * heatmaps.view(-1, H*W), -1).view(-1, H, W) + x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device ), torch.arange(H, device = heatmaps.device ), indexing = 'xy') + x = x - (W//2) + y = y - (H//2) + #pdb.set_trace() + coords_x = (x[None, ...] * heatmaps) + coords_y = (y[None, ...] * heatmaps) + coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H*W, 2) + coords = coords.sum(1) + + return coords + + +def check_accuracy1(X, Y, pts1 = None, pts2 = None): + with torch.no_grad(): + #dist_mat = torch.cdist(X,Y) + dist_mat = X @ Y.t() + nn = torch.argmax(dist_mat, dim=1) + #nn = torch.argmin(dist_mat, dim=1) + correct = nn == torch.arange(len(X), device = X.device) + + acc = correct.sum().item() / len(X) + return acc + +def check_accuracy(X, Y, thr = 0.0): + with torch.no_grad(): + #dist_mat = torch.cdist(X,Y) + + inds = matcher(X[None], Y[None], thr) + batch_inds = inds[:,0] + + # count the number of inds + acc = len(batch_inds) / len(X) + + return acc + +def get_nb_trainable_params(model): + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + nb_params = sum([np.prod(p.size()) for p in model_parameters]) + + print('Number of trainable parameters: {:d}'.format(nb_params)) + + +def plot_keypoints(image, kpts, radius=2, color=(255, 0, 0)): + image = image.cpu().detach().numpy() if isinstance(image, torch.Tensor) else image + kpts = kpts.cpu().detach().numpy() if isinstance(kpts, torch.Tensor) else kpts + + if image.dtype is not np.dtype('uint8'): + image = image * 255 + image = image.astype(np.uint8) + + if len(image.shape) == 2 or image.shape[2] == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + out = np.ascontiguousarray(deepcopy(image)) + kpts = np.round(kpts).astype(int) + + for kpt in kpts: + y0, x0 = kpt + cv2.drawMarker(out, (x0, y0), color, cv2.MARKER_CROSS, radius) + + # cv2.circle(out, (x0, y0), radius, color, -1, lineType=cv2.LINE_4) + return out + +def save_image_in_actual_size(image, name): + import matplotlib.pyplot as plt + + dpi = 100 + height, width = image.shape[:2] + + # What size does the figure need to be in inches to fit the image? + figsize = width / float(dpi), height / float(dpi) + + # Create a figure of the right size with one axes that takes up the full figure + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0, 0, 1, 1]) + + # Hide spines, ticks, etc. + ax.axis('off') + + # Display the image. + if len(image.shape) == 3: + ax.imshow(image, cmap='gray') + elif len(image.shape) == 2: + if image.dtype == np.uint8: + ax.imshow(image, cmap='gray') + else: + ax.imshow(image) + ax.text(20, 20, f"Range: {image.min():g}~{image.max():g}", color='red') + + # save the image + plt.savefig(name, dpi=dpi) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index dfbf34dffb336fa9ba4d4e00f1b0746584ce8914..3e2ddc644172ecdfb35a2b27d2170ab4c3bfb5a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,7 @@ poselib protobuf psutil pycolmap==0.6.1 +pydantic==2.10.6 pytlsd pytorch-lightning==1.4.9 PyYAML @@ -40,4 +41,3 @@ torchvision==0.19.0 tqdm uvicorn yacs -pydantic==2.10.6