Spaces:
Runtime error
Runtime error
| # -*- coding: UTF-8 -*- | |
| '''================================================= | |
| @Project -> File r2d2 -> gm | |
| @IDE PyCharm | |
| @Author [email protected] | |
| @Date 25/05/2023 10:09 | |
| ==================================================''' | |
| import torch | |
| from localization.base_model import BaseModel | |
| from nets.gm import GM as GMatcher | |
class GM(BaseModel):
    """Adapter exposing the ``nets.gm.GM`` matcher through the ``BaseModel`` interface.

    ``_init`` builds the network from the config, switches it to eval mode and
    loads its weights from ``conf['weight_path']``. ``_forward`` runs the network
    with gradient tracking disabled.
    """

    default_config = {
        'descriptor_dim': 128,
        'hidden_dim': 256,
        'weights': 'indoor',
        'keypoint_encoder': [32, 64, 128, 256],
        # 18 entries: 9 alternating ['self', 'cross'] pairs.
        'GNN_layers': ['self', 'cross'] * 9,
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
        'with_pose': False,
        'n_layers': 9,
        'n_min_tokens': 256,
        'with_sinkhorn': True,
        'ac_fn': 'relu',
        'norm_fn': 'bn',
        # No usable default: callers must supply a checkpoint path.
        'weight_path': None,
    }
    required_inputs = [
        'image0', 'keypoints0', 'scores0', 'descriptors0',
        'image1', 'keypoints1', 'scores1', 'descriptors1',
    ]

    def _init(self, conf):
        """Build the matcher network and load its checkpoint.

        Args:
            conf: configuration mapping passed to ``nets.gm.GM``; presumably
                ``default_config`` merged with user overrides by ``BaseModel``
                (TODO confirm). ``conf['weight_path']`` must name a file loadable
                by ``torch.load`` whose ``'model'`` entry is the network's
                state dict.

        Raises:
            ValueError: if ``conf['weight_path']`` is ``None`` (the default).
        """
        weight_path = conf['weight_path']
        if weight_path is None:
            # Fail early with a clear message instead of an opaque error from
            # inside torch.load(None, ...).
            raise ValueError(
                "GM requires conf['weight_path'] to point to a checkpoint file; got None."
            )
        self.net = GMatcher(config=conf).eval()
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # Weights are mapped to CPU regardless of the device they were saved on.
        checkpoint = torch.load(weight_path, map_location='cpu')
        # strict=True: every parameter name must match the architecture exactly.
        self.net.load_state_dict(checkpoint['model'], strict=True)

    def _forward(self, data):
        """Run the matcher on ``data`` without tracking gradients.

        ``data`` presumably carries the keys listed in ``required_inputs``
        (validation assumed to happen in ``BaseModel`` -- confirm). Returns
        whatever ``nets.gm.GM`` returns.
        """
        with torch.no_grad():
            return self.net(data)