Nan Xue committed on
Commit
4c954ae
·
1 Parent(s): 3132f36
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LEGAL.md +0 -0
  2. LICENSE +21 -0
  3. README.md +117 -12
  4. gradio_demo/inference.py +252 -0
  5. gradio_demo/line_mat_gluestick.py +386 -0
  6. line_matching/run.py +191 -0
  7. line_matching/run_list.py +144 -0
  8. line_matching/two_view_pipeline.py +167 -0
  9. line_matching/wireframe.py +341 -0
  10. predictor/predict.py +131 -0
  11. requirements.txt +21 -0
  12. scalelsd/.gitignore +10 -0
  13. scalelsd/__init__.py +2 -0
  14. scalelsd/base/__init__.py +13 -0
  15. scalelsd/base/csrc/__init__.py +19 -0
  16. scalelsd/base/csrc/binding.cpp +5 -0
  17. scalelsd/base/csrc/linesegment.cu +139 -0
  18. scalelsd/base/csrc/linesegment.h +26 -0
  19. scalelsd/base/show/__init__.py +3 -0
  20. scalelsd/base/show/canvas.py +153 -0
  21. scalelsd/base/show/cli.py +24 -0
  22. scalelsd/base/show/painters.py +80 -0
  23. scalelsd/base/utils/__init__.py +1 -0
  24. scalelsd/base/utils/logger.py +30 -0
  25. scalelsd/base/utils/metric_logger.py +77 -0
  26. scalelsd/base/wireframe.py +110 -0
  27. scalelsd/encoder/__init__.py +1 -0
  28. scalelsd/encoder/hafm.py +152 -0
  29. scalelsd/ssl/backbones/__init__.py +1 -0
  30. scalelsd/ssl/backbones/build.py +28 -0
  31. scalelsd/ssl/backbones/dpt/__init__.py +0 -0
  32. scalelsd/ssl/backbones/dpt/base_model.py +16 -0
  33. scalelsd/ssl/backbones/dpt/blocks.py +388 -0
  34. scalelsd/ssl/backbones/dpt/midas_net.py +77 -0
  35. scalelsd/ssl/backbones/dpt/models.py +115 -0
  36. scalelsd/ssl/backbones/dpt/transforms.py +231 -0
  37. scalelsd/ssl/backbones/dpt/vit.py +586 -0
  38. scalelsd/ssl/backbones/multi_task_head.py +52 -0
  39. scalelsd/ssl/config/__init__.py +2 -0
  40. scalelsd/ssl/config/dataset/hpatches_dataset.yaml +105 -0
  41. scalelsd/ssl/config/dataset/nyu_dataset.yaml +77 -0
  42. scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml +75 -0
  43. scalelsd/ssl/config/dataset/rdnim_dataset.yaml +77 -0
  44. scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml +49 -0
  45. scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml +50 -0
  46. scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml +50 -0
  47. scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml +50 -0
  48. scalelsd/ssl/config/dataset/synthetic_dataset.yaml +51 -0
  49. scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml +86 -0
  50. scalelsd/ssl/config/dataset/wireframe_official_gt.yaml +86 -0
LEGAL.md ADDED
File without changes
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Nan Xue
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,117 @@
- ---
- title: ScaleLSD
- emoji: 🌍
- colorFrom: indigo
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.33.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+
+ # ScaleLSD: Scalable Deep Line Segment Detection Streamlined
+
+ <!-- <a href="https://code.alipay.com/kezeran.kzr/ScaleLSD"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a>&ensp;<a href="https://code.alipay.com/kezeran.kzr/ScaleLSD"><img src="https://img.shields.io/badge/ArXiv-250x.xxxxx-brightgreen"></a>&ensp;<a href="https://code.alipay.com/kezeran.kzr/ScaleLSD"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>&ensp;<a href="https://code.alipay.com/kezeran.kzr/ScaleLSD"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> -->
+
+ <a href="https://ant-research.github.io/scalelsd"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a>&ensp;<a href="https://arxiv.org/abs/2506.09369"><img src="https://img.shields.io/badge/ArXiv-2506.09369-brightgreen"></a>&ensp;<a href="https://huggingface.co/cherubicxn/scalelsd"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
+
+ [Zeran Ke](https://calmke.github.io/)<sup>1,2</sup>, [Bin Tan](https://icetttb.github.io/)<sup>2</sup>, [Xianwei Zheng](https://jszy.whu.edu.cn/zhengxianwei/zh_CN/index.htm)<sup>1</sup>, [Yujun Shen](https://shenyujun.github.io/)<sup>2</sup>, [Tianfu Wu](https://research.ece.ncsu.edu/ivmcl/)<sup>3</sup>, [Nan Xue](https://xuenan.net/)<sup>2†</sup>
+
+ <sup>1</sup>Wuhan University &ensp;&ensp;<sup>2</sup>Ant Group&ensp;&ensp;<sup>3</sup>NC State University
+
+ </div>
+
+ <!-- <img src="assets/teaser.jpg" width="100%"> -->
+
+ ![teaser](assets/teaser.jpg)
+
+ ## ⚙️ Installation
+
+ All code has been successfully tested on:
+
+ - Ubuntu 22.04.5 LTS
+ - CUDA 12.1
+ - Python 3.10
+ - PyTorch 2.5.1
+
+ First, clone this repo:
+
+ ```bash
+ git clone https://github.com/ant-research/scalelsd.git
+ ```
+
+ Then create the conda environment and install the dependencies:
+ ```bash
+ conda create -n scalelsd python=3.10
+ conda activate scalelsd
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+ pip install -r requirements.txt
+ pip install -e .  # install scalelsd locally
+ ```
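+
+ A quick sanity check of the environment (a minimal sketch; it only assumes the editable install above succeeded and that PyTorch was installed with CUDA support):
+ ```bash
+ # Verify that the scalelsd package is importable and that PyTorch sees a GPU
+ python -c "import torch, scalelsd; print('CUDA available:', torch.cuda.is_available())"
+ ```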
+
+ ## 🔥🔍 Gradio Demo
+
+ ### Line Segment Detection
+ Before you start, please download our pre-trained [models](https://huggingface.co/cherubicxn/scalelsd) and place them in the `models` folder. Then run the Gradio demo:
+ ```bash
+ python -m gradio_demo.inference
+ ```
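+
+ One possible way to fetch the pre-trained weights from the command line (a sketch assuming the `huggingface_hub` CLI is installed; downloading the files manually from the model card and placing them in `models/` works just as well):
+ ```bash
+ # Download the ScaleLSD checkpoints from the Hugging Face model card into ./models
+ huggingface-cli download cherubicxn/scalelsd --local-dir models
+ ```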
+
+ ### Line Matching
+ Because our line matching app builds on GlueStick with our ScaleLSD, you need to install [GlueStick](https://github.com/cvg/GlueStick) and download the weights of the GlueStick model. Then run the Gradio demo:
+ ```bash
+ python -m gradio_demo.line_mat_gluestick
+ ```
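+
+ A typical GlueStick setup might look like the following (a sketch only; please follow the official GlueStick README for the authoritative steps, and note that our matching code expects the weights at `resources/weights/checkpoint_GlueStick_MD.tar` under the GlueStick root):
+ ```bash
+ # Install GlueStick from source so that `gluestick` is importable
+ git clone https://github.com/cvg/GlueStick.git
+ cd GlueStick
+ pip install -e .
+ ```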
+
+ ## 🚗 Inference
+
+ To quickly get started with our models for line segment detection, run:
+ ```bash
+ python -m predictor.predict --img $[IMAGE_PATH_OR_FOLDER]
+ ```
+
+ You can also specify more parameters:
+
+ ```bash
+ python -m predictor.predict \
+     --ckpt $[MODEL_PATH] \
+     --img $[IMAGE_PATH_OR_FOLDER] \
+     --ext $[png/pdf/json] \
+     --threshold 10 \
+     --junction-hm 0.1 \
+     --disable-show
+ ```
+
+ ```text
+ OPTIONS:
+     --ckpt CKPT, -c CKPT
+                         Path to the checkpoint file.
+     --img IMG, -i IMG   Path to the image or folder containing images.
+     --ext EXT, -e EXT   Output file extension (png/pdf/json).
+     --threshold THRESHOLD, -t THRESHOLD
+                         Threshold for line segment detection.
+     --junction-hm JUNCTION_HM, -jh JUNCTION_HM
+                         Junction heatmap threshold.
+     --num-junctions NUM_JUNCTIONS, -nj NUM_JUNCTIONS
+                         Max number of junctions to detect.
+     --disable-show      Disable showing the results.
+     --use_lsd           Use LSD-Rectifier for line segment detection.
+     --use_nms           Use Non-Maximum Suppression (NMS) for junction detection.
+ ```
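+
+ For example, a run over the bundled example images with JSON output might look like this (the checkpoint name below is the one used by the Gradio demo; substitute whichever checkpoint you placed in `models/`):
+ ```bash
+ # Detect line segments for every image in assets/figs and export the wireframes as JSON
+ python -m predictor.predict \
+     --ckpt models/scalelsd-vitbase-v1-train-sa1b.pt \
+     --img assets/figs \
+     --ext json \
+     --threshold 10 \
+     --disable-show
+ ```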
+
+ ## 📖 Related Third-party Projects
+
+ - [HAWPv3](https://github.com/cherubicXN/hawp/tree/main)
+ - [DeepLSD](https://github.com/cvg/DeepLSD)
+ - [Progressive-X](https://github.com/danini/progressive-x/tree/vanishing-points)
+ - [GlueStick](https://github.com/cvg/GlueStick)
+ - [GlueFactory](https://github.com/cvg/glue-factory)
+ - [LIMAP](https://github.com/cvg/limap)
+
+ ## 📝 Citation
+
+ If you find our work useful in your research, please consider citing:
+
+ ```bibtex
+ @inproceedings{ScaleLSD,
+     title     = {ScaleLSD: Scalable Deep Line Segment Detection Streamlined},
+     author    = {Zeran Ke and Bin Tan and Xianwei Zheng and Yujun Shen and Tianfu Wu and Nan Xue},
+     booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+     year      = {2025},
+ }
+ ```
gradio_demo/inference.py ADDED
@@ -0,0 +1,252 @@
1
+ import torch
2
+ import cv2
3
+ import os
4
+ import gradio as gr
5
+ import numpy as np
6
+ import random
7
+ from pathlib import Path
8
+ import json
9
+
10
+ from scalelsd.ssl.models.detector import ScaleLSD
11
+ from scalelsd.base import show, WireframeGraph
12
+ from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model
13
+
14
+ # Title for the Gradio interface
15
+ _TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images'
16
+ MAX_SEED = 1000
17
+
18
+
19
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
20
+ """random seed"""
21
+ if randomize_seed:
22
+ seed = random.randint(0, MAX_SEED)
23
+ return seed
24
+
25
+ def stop_run():
26
+ """stop run"""
27
+ return (
28
+ gr.update(value="Run", variant="primary", visible=True),
29
+ gr.update(visible=False),
30
+ )
31
+
32
+ def process_image(
33
+ input_image,
34
+ model_name='scalelsd-vitbase-v2-train-sa1b.pt',
35
+ save_name='temp_output',
36
+ threshold=10,
37
+ junction_threshold_hm=0.008,
38
+ num_junctions_inference=512,
39
+ width=512,
40
+ height=512,
41
+ line_width=2,
42
+ juncs_size=4,
43
+ whitebg=0.0,
44
+ draw_junctions_only=False,
45
+ use_lsd=False,
46
+ use_nms=False,
47
+ edge_color='orange',
48
+ vertex_color='Cyan',
49
+ output_format='png',
50
+ seed=0,
51
+ randomize_seed=False
52
+ ):
53
+ """core processing function for image inference"""
54
+ # set random seed
55
+ seed = int(randomize_seed_fn(seed, randomize_seed))
56
+ fix_seeds(seed)
57
+
58
+ # initialize model
59
+ ckpt = "models/" + model_name
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ model = load_scalelsd_model(ckpt, device)
62
+
63
+ # set model parameters
64
+ model.junction_threshold_hm = junction_threshold_hm
65
+ model.num_junctions_inference = num_junctions_inference
66
+
67
+ # transform input image
68
+ if isinstance(input_image, np.ndarray):
69
+ image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
70
+ else:
71
+ image = cv2.imread(input_image, 0)
72
+
73
+ # resize
74
+ ori_shape = image.shape[:2]
75
+ image_resized = cv2.resize(image.copy(), (width, height))
76
+ image_tensor = torch.from_numpy(image_resized).float() / 255.0
77
+ image_tensor = image_tensor[None, None].to('cuda')
78
+
79
+ # meta data
80
+ meta = {
81
+ 'width': ori_shape[1],
82
+ 'height': ori_shape[0],
83
+ 'filename': '',
84
+ 'use_lsd': use_lsd,
85
+ 'use_nms': use_nms,
86
+ }
87
+
88
+ # inference
89
+ with torch.no_grad():
90
+ outputs, _ = model(image_tensor, meta)
91
+ outputs = outputs[0]
92
+
93
+ # visual results
94
+ painter = show.painters.HAWPainter()
95
+ painter.confidence_threshold = threshold
96
+ painter.line_width = line_width
97
+ painter.marker_size = juncs_size
98
+ if whitebg > 0.0:
99
+ show.Canvas.white_overlay = whitebg
100
+
101
+ temp_folder = "temp_output"
102
+ os.makedirs(temp_folder, exist_ok=True)
103
+ fig_file = f"{temp_folder}/{save_name}.png"
104
+ with show.image_canvas(input_image, fig_file=fig_file) as ax:
105
+ if draw_junctions_only:
106
+ painter.draw_junctions(ax, outputs)
107
+ else:
108
+ painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
109
+ # read the result image
110
+ result_image = cv2.imread(fig_file)
111
+
112
+ if output_format != 'png':
113
+ fig_file = f"{temp_folder}/{save_name}.{output_format}"
114
+ with show.image_canvas(input_image, fig_file=fig_file) as ax:
115
+ if draw_junctions_only:
116
+ painter.draw_junctions(ax, outputs)
117
+ else:
118
+ painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
119
+
120
+ json_file = f"{temp_folder}/{save_name}.json"
121
+ indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'],outputs['lines_pred'])
122
+ wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices, outputs['lines_score'], outputs['width'], outputs['height'])
123
+ with open(json_file, 'w') as f:
124
+ json.dump(wireframe.jsonize(),f)
125
+
126
+
127
+ return result_image[:, :, ::-1], json_file, fig_file
128
+
129
+
130
+ def run_demo():
131
+ """create the Gradio demo interface"""
132
+ css = """
133
+ #col-container {
134
+ margin: 0 auto;
135
+ max-width: 800px;
136
+ }
137
+ """
138
+
139
+ with gr.Blocks(css=css, title=_TITLE) as demo:
140
+ with gr.Column(elem_id="col-container"):
141
+ gr.Markdown(f'# {_TITLE}')
142
+ gr.Markdown("Detect wireframe structures in images using ScaleLSD model")
143
+
144
+ pid = gr.State()
145
+ figs_root = "assets/figs"
146
+ example_images = [os.path.join(figs_root, iname) for iname in os.listdir(figs_root)]
147
+
148
+ with gr.Row():
149
+ input_image = gr.Image(example_images[0], label="Input Image", type="numpy")
150
+ output_image = gr.Image(label="Detection Result")
151
+
152
+ with gr.Row():
153
+ run_btn = gr.Button(value="Run", variant="primary")
154
+ stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
155
+
156
+ with gr.Row():
157
+ json_file = gr.File(label="Download JSON Output", type="filepath")
158
+ image_file = gr.File(label="Download Image Output", type="filepath")
159
+
160
+ with gr.Accordion("Advanced Settings", open=True):
161
+ with gr.Row():
162
+ model_name = gr.Dropdown(
163
+ [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')],
164
+ value='scalelsd-vitbase-v1-train-sa1b.pt',
165
+ label="Model Selection"
166
+ )
167
+
168
+ with gr.Row():
169
+ save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")
170
+
171
+ with gr.Row():
172
+ with gr.Column():
173
+ threshold = gr.Number(10, label="Line Threshold")
174
+ junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
175
+ num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
176
+ width = gr.Number(512, label="Input Width")
177
+ height = gr.Number(512, label="Input Height")
178
+
179
+ with gr.Column():
180
+ draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
181
+ use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
182
+ use_nms = gr.Checkbox(True, label="Use NMS")
183
+ output_format = gr.Dropdown(
184
+ ['png', 'jpg', 'pdf'],
185
+ value='png',
186
+ label="Output Format"
187
+ )
188
+ whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity")
189
+ line_width = gr.Number(2, label="Line Width")
190
+ juncs_size = gr.Number(8, label="Junctions Size")
191
+
192
+ with gr.Row():
193
+ edge_color = gr.Dropdown(
194
+ ['orange', 'midnightblue', 'red', 'green'],
195
+ value='orange',
196
+ label="Edge Color"
197
+ )
198
+ vertex_color = gr.Dropdown(
199
+ ['Cyan', 'deeppink', 'yellow', 'purple'],
200
+ value='Cyan',
201
+ label="Vertex Color"
202
+ )
203
+
204
+ with gr.Row():
205
+ randomize_seed = gr.Checkbox(False, label="Randomize Seed")
206
+ seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
207
+
208
+ gr.Examples(
209
+ examples=example_images,
210
+ inputs=input_image,
211
+ )
212
+
213
+ # star event handlers
214
+ run_event = run_btn.click(
215
+ fn=process_image,
216
+ inputs=[
217
+ input_image,
218
+ model_name,
219
+ save_name,
220
+ threshold,
221
+ junction_threshold_hm,
222
+ num_junctions_inference,
223
+ width,
224
+ height,
225
+ line_width,
226
+ juncs_size,
227
+ whitebg,
228
+ draw_junctions_only,
229
+ use_lsd,
230
+ use_nms,
231
+ edge_color,
232
+ vertex_color,
233
+ output_format,
234
+ seed,
235
+ randomize_seed
236
+ ],
237
+ outputs=[output_image, json_file, image_file],
238
+ )
239
+
240
+ # stop event handlers
241
+ stop_btn.click(
242
+ fn=stop_run,
243
+ outputs=[run_btn, stop_btn],
244
+ cancels=[run_event],
245
+ queue=False,
246
+ )
247
+
248
+
249
+ return demo
250
+
251
+ if __name__ == "__main__":
252
+ run_demo().launch()
gradio_demo/line_mat_gluestick.py ADDED
@@ -0,0 +1,386 @@
1
+ import argparse
2
+ import os
3
+ from os.path import join
4
+ import sys
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ from matplotlib import pyplot as plt
9
+ from tqdm import tqdm
10
+ import gradio as gr
11
+ import random
12
+
13
+ from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
14
+ from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
15
+
16
+ from scalelsd.ssl.models.detector import ScaleLSD
17
+ from scalelsd.base import show, WireframeGraph
18
+ from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography
19
+ from scalelsd.ssl.misc.train_utils import fix_seeds
20
+ from line_matching.two_view_pipeline import TwoViewPipeline
21
+
22
+ from kornia.geometry import warp_perspective,transform_points
23
+
24
+ class HADConfig:
25
+ num_iter = 1
26
+ valid_border_margin = 3
27
+ translation = True
28
+ rotation = True
29
+ scale = True
30
+ perspective = True
31
+ scaling_amplitude = 0.2
32
+ perspective_amplitude_x = 0.2
33
+ perspective_amplitude_y = 0.2
34
+ allow_artifacts = False
35
+ patch_ratio = 0.85
36
+ had_cfg = HADConfig()
37
+
38
+ # Evaluation config
39
+ default_conf = {
40
+ 'name': 'two_view_pipeline',
41
+ 'use_lines': True,
42
+ 'extractor': {
43
+ 'name': 'wireframe',
44
+ 'sp_params': {
45
+ 'force_num_keypoints': False,
46
+ 'max_num_keypoints': 2048,
47
+ },
48
+ 'wireframe_params': {
49
+ 'merge_points': True,
50
+ 'merge_line_endpoints': True,
51
+ # 'merge_line_endpoints': False,
52
+ },
53
+ 'max_n_lines': 512,
54
+ },
55
+ 'matcher': {
56
+ 'name': 'gluestick',
57
+ 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
58
+ 'trainable': False,
59
+ },
60
+ 'ground_truth': {
61
+ 'from_pose_depth': False,
62
+ }
63
+ }
64
+
65
+ # Title for the Gradio interface
66
+ _TITLE = 'ScaleLSD-GlueStick Line Matching'
67
+ MAX_SEED = 1000
68
+
69
+ def sample_homographics(height, width):
70
+
71
+ def scale_homography(H, stride):
72
+ H_scaled = H.clone()
73
+ H_scaled[:, :, 2, :2] *= stride
74
+ H_scaled[:, :, :2, 2] /= stride
75
+ return H_scaled
76
+
77
+ homographic = sample_homography(
78
+ shape = (height, width),
79
+ perspective = had_cfg.perspective,
80
+ scaling = had_cfg.scale,
81
+ rotation = had_cfg.rotation,
82
+ translation = had_cfg.translation,
83
+ scaling_amplitude = had_cfg.scaling_amplitude,
84
+ perspective_amplitude_x = had_cfg.perspective_amplitude_x,
85
+ perspective_amplitude_y = had_cfg.perspective_amplitude_y,
86
+ patch_ratio = had_cfg.patch_ratio,
87
+ allow_artifacts = False
88
+ )[0]
89
+
90
+ homographic = torch.from_numpy(homographic[None]).float().cuda()
91
+ homographic_inv = torch.inverse(homographic)
92
+
93
+ H = {
94
+ 'h.1': homographic,
95
+ 'ih.1': homographic_inv,
96
+ }
97
+
98
+ return H
99
+
100
+ def trans_image_with_homograpy(image):
101
+ h, w = image.shape[:2]
102
+ H = sample_homographics(height=h, width=w)
103
+
104
+ image_warped = warp_perspective(torch.Tensor(image).permute(2,0,1)[None].cuda(), H['h.1'], (h,w))
105
+ image_warped_ = image_warped[0].permute(1,2,0).cpu().numpy().astype(np.uint8)
106
+ plt.imshow(image_warped_)
107
+ plt.show()
108
+ return image_warped_
109
+
110
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
111
+ """random seed"""
112
+ if randomize_seed:
113
+ seed = random.randint(0, MAX_SEED)
114
+ return seed
115
+
116
+ def stop_run():
117
+ """stop run"""
118
+ return (
119
+ gr.update(value="Run", variant="primary", visible=True),
120
+ gr.update(visible=False),
121
+ )
122
+
123
+ def clear_image2():
124
+ return None # returning None will clear the image component
125
+
126
+ def process_image(
127
+ input_image1='assets/figs/sa_1119229.jpg',
128
+ input_image2=None,
129
+ model_name='scalelsd-vitbase-v1-train-sa1b.pt',
130
+ save_name='temp',
131
+ threshold=5,
132
+ junction_threshold_hm=0.008,
133
+ num_junctions_inference=4096,
134
+ width=512,
135
+ height=512,
136
+ line_width=2,
137
+ juncs_size=4,
138
+ whitebg=1.0,
139
+ draw_junctions_only=False,
140
+ use_lsd=False,
141
+ use_nms=False,
142
+ edge_color='midnightblue',
143
+ vertex_color='deeppink',
144
+ output_format='png',
145
+ seed=0,
146
+ randomize_seed=False
147
+ ):
148
+ """core processing function for image inference"""
149
+ # set random seed
150
+ seed = int(randomize_seed_fn(seed, randomize_seed))
151
+ fix_seeds(seed)
152
+
153
+ conf = {
154
+ 'model_name': model_name,
155
+ 'threshold': threshold,
156
+ 'junction_threshold_hm': junction_threshold_hm,
157
+ 'num_junctions_inference': num_junctions_inference,
158
+ 'use_lsd': use_lsd,
159
+ 'use_nms': use_nms,
160
+ 'width': width,
161
+ 'height': height,
162
+ }
163
+
164
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
165
+ pipeline_model = TwoViewPipeline(default_conf).to(device).eval()
166
+ pipeline_model.extractor.update_conf(conf)
167
+
168
+ saveto = 'temp_output/matching_results'
+ os.makedirs(saveto, exist_ok=True)  # make sure the output directory exists before writing images
169
+ image1 = cv2.cvtColor(input_image1, cv2.COLOR_BGR2RGB)
170
+ cv2.imwrite(f'{saveto}/image.png', image1)
171
+ input_image1 = f'{saveto}/image.png'
172
+ if input_image2 is None:
173
+ image2 = trans_image_with_homograpy(image1)
174
+ else:
175
+ image2 = cv2.cvtColor(input_image2, cv2.COLOR_BGR2RGB)
176
+ cv2.imwrite(f'{saveto}/image2.png', image2)
177
+ input_image2 = f'{saveto}/image2.png'
178
+
179
+ gray0 = cv2.imread(input_image1, 0)
180
+ gray1 = cv2.imread(input_image2, 0)
181
+
182
+ torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
183
+ torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
184
+
185
+ x = {'image0': torch_gray0, 'image1': torch_gray1}
186
+ pred = pipeline_model(x)
187
+
188
+ pred = batch_to_np(pred)
189
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
190
+ m0 = pred["matches0"]
191
+
192
+ line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
193
+ line_matches = pred["line_matches0"]
194
+
195
+ valid_matches = m0 != -1
196
+ match_indices = m0[valid_matches]
197
+ matched_kps0 = kp0[valid_matches]
198
+ matched_kps1 = kp1[match_indices]
199
+
200
+ valid_matches = line_matches != -1
201
+ match_indices = line_matches[valid_matches]
202
+ matched_lines0 = line_seg0[valid_matches]
203
+ matched_lines1 = line_seg1[match_indices]
204
+
205
+ img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
206
+
207
+ det_file = f'{saveto}/{save_name}_det.png'  # detected-lines figure
208
+ plot_images([img0, img1], dpi=200, pad=2.0)
209
+ plot_lines([line_seg0, line_seg1], ps=4, lw=2)
210
+ plt.gcf().canvas.manager.set_window_title('Detected Lines')
211
+ # plt.tight_layout()
212
+ plt.savefig(det_file)
213
+ det_image = cv2.imread(det_file)[:,:,::-1]
214
+
215
+ mat_file = f'{saveto}/{save_name}_mat.png'  # line-matches figure
216
+ plot_images([img0, img1], dpi=200, pad=2.0)
217
+ plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
218
+ plt.gcf().canvas.manager.set_window_title('Line Matches')
219
+ # plt.tight_layout()
220
+ plt.savefig(mat_file)
221
+ mat_image = cv2.imread(mat_file)[:,:,::-1]
222
+
223
+ show.Canvas.white_overlay = whitebg
224
+ painter = show.painters.HAWPainter()
225
+
226
+ fig_file = f'{saveto}/{save_name}_det1.png'
227
+ outputs = {'lines_pred': line_seg0.reshape(-1,4)}
228
+ with show.image_canvas(input_image1, fig_file=fig_file) as ax:
229
+ painter.draw_wireframe(ax,outputs, edge_color=edge_color, vertex_color=vertex_color)
230
+ det1_image = cv2.imread(fig_file)[:,:,::-1]
231
+
232
+ fig_file = f'{saveto}/{save_name}_det2.png'
233
+ outputs = {'lines_pred': line_seg1.reshape(-1,4)}
234
+ with show.image_canvas(input_image2, fig_file=fig_file) as ax:
235
+ painter.draw_wireframe(ax,outputs, edge_color=edge_color, vertex_color=vertex_color)
236
+ det2_image = cv2.imread(fig_file)[:,:,::-1]
237
+
238
+ return image2[:,:,::-1], mat_image, det_image, det1_image, det2_image, mat_file, det_file
239
+
240
+
241
+ def demo():
242
+ """create the Gradio demo interface"""
243
+ css = """
244
+ #col-container {
245
+ margin: 0 auto;
246
+ max-width: 800px;
247
+ }
248
+ """
249
+
250
+ with gr.Blocks(css=css, title=_TITLE) as demo:
251
+ with gr.Column(elem_id="col-container"):
252
+ gr.Markdown(f'# {_TITLE}')
253
+ gr.Markdown("Detect wireframe structures in images using ScaleLSD model")
254
+
255
+ pid = gr.State()
256
+ figs_root = "assets/mat_figs"
257
+ example_single = [os.path.join(figs_root, 'single', iname) for iname in os.listdir(figs_root+'/single')]
258
+ example_pairs = [[img, None] for img in example_single]
259
+ example_pairs += [
260
+ [os.path.join(figs_root, 'pairs', f'ref_{i}.png'),
261
+ os.path.join(figs_root, 'pairs', f'tgt_{i}.png')]
262
+ for i in [10, 72, 76, 95, 149, 151]
263
+ ]
264
+
265
+ with gr.Row():
266
+ input_image1 = gr.Image(example_pairs[0][0], label="Input Image1", type="numpy")
267
+ input_image2 = gr.Image(label="Input Image2", type="numpy")
268
+
269
+ with gr.Row():
270
+ mat_images = gr.Image(label="Matching Results")
271
+ with gr.Row():
272
+ det_images = gr.Image(label="Detection Results")
273
+ with gr.Row():
274
+ det_image1 = gr.Image(label="Detection1")
275
+ det_image2 = gr.Image(label="Detection2")
276
+
277
+ with gr.Row():
278
+ run_btn = gr.Button(value="Run", variant="primary")
279
+ stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
280
+
281
+ with gr.Row():
282
+ mat_file = gr.File(label="Download Matching Result", type="filepath")
283
+ det_file = gr.File(label="Download Detection Result", type="filepath")
284
+
285
+ with gr.Accordion("Advanced Settings", open=True):
286
+ with gr.Row():
287
+ model_name = gr.Dropdown(
288
+ [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')],
289
+ value='scalelsd-vitbase-v1-train-sa1b.pt',
290
+ label="Model Selection"
291
+ )
292
+
293
+ with gr.Row():
294
+ save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")
295
+
296
+ with gr.Row():
297
+ with gr.Column():
298
+ threshold = gr.Number(10, label="Line Threshold")
299
+ junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
300
+ num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
301
+ width = gr.Number(512, label="Input Width")
302
+ height = gr.Number(512, label="Input Height")
303
+
304
+ with gr.Column():
305
+ draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
306
+ use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
307
+ use_nms = gr.Checkbox(True, label="Use NMS")
308
+ output_format = gr.Dropdown(
309
+ ['png', 'jpg', 'pdf'],
310
+ value='png',
311
+ label="Output Format"
312
+ )
313
+ whitebg = gr.Slider(0.0, 1.0, value=1.0, label="White Background Opacity")
314
+ line_width = gr.Number(2, label="Line Width")
315
+ juncs_size = gr.Number(8, label="Junctions Size")
316
+
317
+ with gr.Row():
318
+ edge_color = gr.Dropdown(
319
+ ['orange', 'midnightblue', 'red', 'green'],
320
+ value='midnightblue',
321
+ label="Edge Color"
322
+ )
323
+ vertex_color = gr.Dropdown(
324
+ ['Cyan', 'deeppink', 'yellow', 'purple'],
325
+ value='deeppink',
326
+ label="Vertex Color"
327
+ )
328
+
329
+ with gr.Row():
330
+ randomize_seed = gr.Checkbox(False, label="Randomize Seed")
331
+ seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
332
+
333
+ gr.Examples(
334
+ examples=example_pairs,
335
+ inputs=[input_image1, input_image2]
336
+ )
337
+
338
+ # start event handlers
339
+ run_event = run_btn.click(
340
+ fn=process_image,
341
+ inputs=[
342
+ input_image1,
343
+ input_image2,
344
+ model_name,
345
+ save_name,
346
+ threshold,
347
+ junction_threshold_hm,
348
+ num_junctions_inference,
349
+ width,
350
+ height,
351
+ line_width,
352
+ juncs_size,
353
+ whitebg,
354
+ draw_junctions_only,
355
+ use_lsd,
356
+ use_nms,
357
+ edge_color,
358
+ vertex_color,
359
+ output_format,
360
+ seed,
361
+ randomize_seed
362
+ ],
363
+ outputs=[input_image2, mat_images, det_images, det_image1, det_image2, mat_file, det_file],
364
+ )
365
+
366
+ # stop event handlers
367
+ stop_btn.click(
368
+ fn=stop_run,
369
+ outputs=[run_btn, stop_btn],
370
+ cancels=[run_event],
371
+ queue=False,
372
+ )
373
+
374
+ # When image1 changes, image2 is cleared
375
+ input_image1.change(
376
+ fn=clear_image2,
377
+ outputs=input_image2
378
+ )
379
+
380
+
381
+ return demo
382
+
383
+ if __name__ == "__main__":
384
+ # Launch the app
385
+ demo = demo()
386
+ demo.launch()
line_matching/run.py ADDED
@@ -0,0 +1,191 @@
1
+ import argparse
2
+ import os
3
+ from os.path import join
4
+ import sys
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ from matplotlib import pyplot as plt
9
+ from tqdm import tqdm
10
+
11
+ from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
12
+ from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
13
+ from line_matching.two_view_pipeline import TwoViewPipeline
14
+
15
+ from scalelsd.base import show, WireframeGraph
16
+ from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography
17
+ from kornia.geometry import warp_perspective,transform_points
18
+
19
+ class HADConfig:
20
+ num_iter = 1
21
+ valid_border_margin = 3
22
+ translation = True
23
+ rotation = True
24
+ scale = True
25
+ perspective = True
26
+ scaling_amplitude = 0.2
27
+ perspective_amplitude_x = 0.2
28
+ perspective_amplitude_y = 0.2
29
+ allow_artifacts = False
30
+ patch_ratio = 0.85
31
+ had_cfg = HADConfig()
32
+
33
+ def sample_homographics(height, width):
34
+
35
+ def scale_homography(H, stride):
36
+ H_scaled = H.clone()
37
+ H_scaled[:, :, 2, :2] *= stride
38
+ H_scaled[:, :, :2, 2] /= stride
39
+ return H_scaled
40
+
41
+ homographic = sample_homography(
42
+ shape = (height, width),
43
+ perspective = had_cfg.perspective,
44
+ scaling = had_cfg.scale,
45
+ rotation = had_cfg.rotation,
46
+ translation = had_cfg.translation,
47
+ scaling_amplitude = had_cfg.scaling_amplitude,
48
+ perspective_amplitude_x = had_cfg.perspective_amplitude_x,
49
+ perspective_amplitude_y = had_cfg.perspective_amplitude_y,
50
+ patch_ratio = had_cfg.patch_ratio,
51
+ allow_artifacts = False
52
+ )[0]
53
+
54
+ homographic = torch.from_numpy(homographic[None]).float().cuda()
55
+ homographic_inv = torch.inverse(homographic)
56
+
57
+ H = {
58
+ 'h.1': homographic,
59
+ 'ih.1': homographic_inv,
60
+ }
61
+
62
+ return H
63
+
64
+ def trans_image_with_homograpy(image):
65
+ h, w = image.shape[:2]
66
+ H = sample_homographics(height=h, width=w)
67
+
68
+ image_warped = warp_perspective(torch.Tensor(image).permute(2,0,1)[None].cuda(), H['h.1'], (h,w))
69
+ image_warped_ = image_warped[0].permute(1,2,0).cpu().numpy().astype(np.uint8)
70
+ plt.imshow(image_warped_)
71
+ plt.show()
72
+ return image_warped_
73
+
74
+
75
+ def main():
76
+ # Parse input parameters
77
+ parser = argparse.ArgumentParser(
78
+ prog='GlueStick Demo',
79
+ description='Demo app to show the point and line matches obtained by GlueStick')
80
+ parser.add_argument('-img1', default='assets/figs/sa_1119229.jpg')
81
+ parser.add_argument('-img2', default=None)
82
+ parser.add_argument('--max_pts', type=int, default=1000)
83
+ parser.add_argument('--max_lines', type=int, default=300)
84
+ parser.add_argument('--model', type=str, default='models/paper-sa1b-997pkgs-model.pt')
85
+ args = parser.parse_args()
86
+
87
+ # important
88
+ if args.img1 is None and args.img2 is None:
89
+ raise ValueError("Input at least one path of image1 or image2")
90
+
91
+ # Evaluation config
92
+ conf = {
93
+ 'name': 'two_view_pipeline',
94
+ 'use_lines': True,
95
+ 'extractor': {
96
+ 'name': 'wireframe',
97
+ 'sp_params': {
98
+ 'force_num_keypoints': False,
99
+ 'max_num_keypoints': args.max_pts,
100
+ },
101
+ 'wireframe_params': {
102
+ 'merge_points': True,
103
+ 'merge_line_endpoints': True,
104
+ # 'merge_line_endpoints': False,
105
+ },
106
+ 'max_n_lines': args.max_lines,
107
+ },
108
+ 'matcher': {
109
+ 'name': 'gluestick',
110
+ 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
111
+ 'trainable': False,
112
+ },
113
+ 'ground_truth': {
114
+ 'from_pose_depth': False,
115
+ }
116
+ }
117
+
118
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
119
+ pipeline_model = TwoViewPipeline(conf).to(device).eval()
120
+ pipeline_model.extractor.update_conf(None)
121
+
122
+ saveto = f'temp_output/matching_results'
123
+ os.makedirs(saveto, exist_ok=True)
124
+
125
+ image1 = cv2.cvtColor(cv2.imread(args.img1), cv2.COLOR_BGR2RGB)
126
+ if args.img2 is None:
127
+ image2 = trans_image_with_homograpy(image1)
128
+ cv2.imwrite(f'{saveto}/warped_image.png', image2)
129
+ args.img2 = f'{saveto}/warped_image.png'
130
+
131
+ gray0 = cv2.imread(args.img1, 0)
132
+ gray1 = cv2.imread(args.img2, 0)
133
+
134
+ torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
135
+ torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
136
+
137
+ x = {'image0': torch_gray0, 'image1': torch_gray1}
138
+ pred = pipeline_model(x)
139
+
140
+ pred = batch_to_np(pred)
141
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
142
+ m0 = pred["matches0"]
143
+
144
+ line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
145
+ line_matches = pred["line_matches0"]
146
+
147
+ valid_matches = m0 != -1
148
+ match_indices = m0[valid_matches]
149
+ matched_kps0 = kp0[valid_matches]
150
+ matched_kps1 = kp1[match_indices]
151
+
152
+ valid_matches = line_matches != -1
153
+ match_indices = line_matches[valid_matches]
154
+ matched_lines0 = line_seg0[valid_matches]
155
+ matched_lines1 = line_seg1[match_indices]
156
+
157
+ # Plot the matches
158
+ gray0 = cv2.imread(args.img1, 0)
159
+ gray1 = cv2.imread(args.img2, 0)
160
+ img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
161
+
162
+ plot_images([img0, img1], dpi=200, pad=2.0)
163
+ plot_lines([line_seg0, line_seg1], ps=4, lw=2)
164
+ plt.gcf().canvas.manager.set_window_title('Detected Lines')
165
+ # plt.tight_layout()
166
+ plt.savefig(f'{saveto}/det.png')
167
+
168
+ plot_images([img0, img1], dpi=200, pad=2.0)
169
+ plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
170
+ plt.gcf().canvas.manager.set_window_title('Line Matches')
171
+ # plt.tight_layout()
172
+ plt.savefig(f'{saveto}/mat.png')
173
+
174
+ whitebg = 1
175
+ show.Canvas.white_overlay = whitebg
176
+ painter = show.painters.HAWPainter()
177
+
178
+ fig_file = f'{saveto}/det1.png'
179
+ outputs = {'lines_pred': line_seg0.reshape(-1,4)}
180
+ with show.image_canvas(args.img1, fig_file=fig_file) as ax:
181
+ # painter.draw_wireframe(ax,outputs, edge_color='orange', vertex_color='Cyan')
182
+ painter.draw_wireframe(ax,outputs, edge_color='midnightblue', vertex_color='deeppink')
183
+ fig_file = f'{saveto}/det2.png'
184
+ outputs = {'lines_pred': line_seg1.reshape(-1,4)}
185
+ with show.image_canvas(args.img2, fig_file=fig_file) as ax:
186
+ painter.draw_wireframe(ax,outputs, edge_color='midnightblue', vertex_color='deeppink')
187
+
188
+
189
+
190
+ if __name__ == '__main__':
191
+ main()
line_matching/run_list.py ADDED
@@ -0,0 +1,144 @@
1
+ import argparse
2
+ import os
3
+ from os.path import join
4
+ import sys
5
+
6
+ import cv2
7
+ import torch
8
+ from matplotlib import pyplot as plt
9
+ from tqdm import tqdm
10
+
11
+ from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
12
+ from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
13
+ # from gluestick.models.two_view_pipeline import TwoViewPipeline
14
+ from line_matching.two_view_pipeline import TwoViewPipeline
15
+
16
+ from scalelsd.base import show, WireframeGraph
17
+
18
+ def main():
19
+ # Parse input parameters
20
+ parser = argparse.ArgumentParser(
21
+ prog='GlueStick Demo',
22
+ description='Demo app to show the point and line matches obtained by GlueStick')
23
+ parser.add_argument('-inum', default=None, type=int)
24
+ parser.add_argument('-imax', default=None, type=int)
25
+ parser.add_argument('-img1', default=join('resources' + os.path.sep + 'img1.jpg'))
26
+ parser.add_argument('-img2', default=join('resources' + os.path.sep + 'img2.jpg'))
27
+ parser.add_argument('--max_pts', type=int, default=1000)
28
+ parser.add_argument('--max_lines', type=int, default=300)
29
+ parser.add_argument('--model', default='scalelsd', type=str)
30
+ parser.add_argument('--test_root', type=str, default='data-ssl/0images-pre/')
31
+ args = parser.parse_args()
32
+
33
+ # Evaluation config
34
+ conf = {
35
+ 'name': 'two_view_pipeline',
36
+ 'use_lines': True,
37
+ 'extractor': {
38
+ 'name': 'wireframe',
39
+ 'sp_params': {
40
+ 'force_num_keypoints': False,
41
+ 'max_num_keypoints': args.max_pts,
42
+ },
43
+ 'wireframe_params': {
44
+ 'merge_points': True,
45
+ 'merge_line_endpoints': True,
46
+ # 'merge_line_endpoints': False,
47
+ },
48
+ 'max_n_lines': args.max_lines,
49
+ },
50
+ 'matcher': {
51
+ 'name': 'gluestick',
52
+ 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
53
+ 'trainable': False,
54
+ },
55
+ 'ground_truth': {
56
+ 'from_pose_depth': False,
57
+ }
58
+ }
59
+
60
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
61
+ pipeline_model = TwoViewPipeline(conf).to(device).eval()
62
+
63
+ pipeline_model.extractor.update_conf(None)
64
+
65
+ md = args.model
66
+
67
+ root = args.test_root
68
+ if args.inum is not None:
69
+ ids = [args.inum]
70
+ elif args.imax is not None:
71
+ ids = range(args.inum, args.imax+1)
72
+ else:
73
+ l_imgs = int(len(os.listdir(root))/2)
74
+ ids = range(l_imgs)
75
+
76
+ for id in tqdm(ids):
77
+ saveto = f'temp_output/matching_results/{md}/{id}'
78
+ os.makedirs(saveto, exist_ok=True)
79
+
80
+ args.img1 = root + f'ref_{str(id)}.png'
81
+ args.img2 = root + f'tgt_{str(id)}.png'
82
+
83
+ gray0 = cv2.imread(args.img1, 0)
84
+ gray1 = cv2.imread(args.img2, 0)
85
+
86
+ torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
87
+ torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
88
+
89
+ x = {'image0': torch_gray0, 'image1': torch_gray1}
90
+ pred = pipeline_model(x)
91
+
92
+ pred = batch_to_np(pred)
93
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
94
+ m0 = pred["matches0"]
95
+
96
+ line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
97
+ line_matches = pred["line_matches0"]
98
+
99
+ valid_matches = m0 != -1
100
+ match_indices = m0[valid_matches]
101
+ matched_kps0 = kp0[valid_matches]
102
+ matched_kps1 = kp1[match_indices]
103
+
104
+ valid_matches = line_matches != -1
105
+ match_indices = line_matches[valid_matches]
106
+ matched_lines0 = line_seg0[valid_matches]
107
+ matched_lines1 = line_seg1[match_indices]
108
+
109
+ # Plot the matches
110
+ gray0 = cv2.imread(args.img1, 0)
111
+ gray1 = cv2.imread(args.img2, 0)
112
+ img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
113
+
114
+ plot_images([img0, img1], dpi=200, pad=2.0)
115
+ plot_lines([line_seg0, line_seg1], ps=4, lw=2)
116
+ plt.gcf().canvas.manager.set_window_title('Detected Lines')
117
+ # plt.tight_layout()
118
+ plt.savefig(f'{saveto}/{md}_det_{id}.png')
119
+
120
+ plot_images([img0, img1], dpi=200, pad=2.0)
121
+ plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
122
+ plt.gcf().canvas.manager.set_window_title('Line Matches')
123
+ # plt.tight_layout()
124
+ plt.savefig(f'{saveto}/{md}_mat_{id}.png')
125
+
126
+ whitebg = 1
127
+ show.Canvas.white_overlay = whitebg
128
+ painter = show.painters.HAWPainter()
129
+
130
+ fig_file = f'{saveto}/{md}_det1.png'
131
+ outputs = {'lines_pred': line_seg0.reshape(-1,4)}
132
+ with show.image_canvas(args.img1, fig_file=fig_file) as ax:
133
+ # painter.draw_wireframe(ax,outputs, edge_color='orange', vertex_color='Cyan')
134
+ painter.draw_wireframe(ax,outputs, edge_color='midnightblue', vertex_color='deeppink')
135
+ fig_file = f'{saveto}/{md}_det2.png'
136
+ outputs = {'lines_pred': line_seg1.reshape(-1,4)}
137
+ with show.image_canvas(args.img2, fig_file=fig_file) as ax:
138
+ # painter.draw_wireframe(ax,outputs, edge_color='orange', vertex_color='Cyan')
139
+ painter.draw_wireframe(ax,outputs, edge_color='midnightblue', vertex_color='deeppink')
140
+
141
+
142
+
143
+ if __name__ == '__main__':
144
+ main()
line_matching/two_view_pipeline.py ADDED
@@ -0,0 +1,167 @@
1
+ """
2
+ A two-view sparse feature matching pipeline.
3
+
4
+ This model contains sub-models for each step:
5
+ feature extraction, feature matching, outlier filtering, pose estimation.
6
+ Each step is optional, and the features or matches can be provided as input.
7
+ Default: SuperPoint with nearest neighbor matching.
8
+
9
+ Convention for the matches: m0[i] is the index of the keypoint in image 1
10
+ that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
11
+ """
12
+
13
+ import numpy as np
14
+ import torch
15
+
16
+ from gluestick import get_model
17
+ from gluestick.models.base_model import BaseModel
18
+ from line_matching.wireframe import SPWireframeDescriptor
19
+
20
+
21
+ def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
22
+ """Keep only keypoints in one of the four quadrant of the image."""
23
+ h2, w2 = h // 2, w // 2
24
+ w_x = np.random.choice([0, w2])
25
+ w_y = np.random.choice([0, h2])
26
+ valid_mask = ((keypoints[..., 0] >= w_x)
27
+ & (keypoints[..., 0] < w_x + w2)
28
+ & (keypoints[..., 1] >= w_y)
29
+ & (keypoints[..., 1] < w_y + h2))
30
+ keypoints = keypoints[valid_mask][None]
31
+ scores = scores[valid_mask][None]
32
+ descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
33
+ return keypoints, scores, descs
34
+
35
+
36
+ def keep_random_kp_subset(keypoints, scores, descs, num_selected):
37
+ """Keep a random subset of keypoints."""
38
+ num_kp = keypoints.shape[1]
39
+ selected_kp = torch.randperm(num_kp)[:num_selected]
40
+ keypoints = keypoints[:, selected_kp]
41
+ scores = scores[:, selected_kp]
42
+ descs = descs[:, :, selected_kp]
43
+ return keypoints, scores, descs
44
+
45
+
46
+ def keep_best_kp_subset(keypoints, scores, descs, num_selected):
47
+ """Keep the top num_selected best keypoints."""
48
+ sorted_indices = torch.sort(scores, dim=1)[1]
49
+ selected_kp = sorted_indices[:, -num_selected:]
50
+ keypoints = torch.gather(keypoints, 1,
51
+ selected_kp[:, :, None].repeat(1, 1, 2))
52
+ scores = torch.gather(scores, 1, selected_kp)
53
+ descs = torch.gather(descs, 2,
54
+ selected_kp[:, None].repeat(1, descs.shape[1], 1))
55
+ return keypoints, scores, descs
56
+
57
+
58
+ class TwoViewPipeline(BaseModel):
59
+ default_conf = {
60
+ 'extractor': {
61
+ 'name': 'superpoint',
62
+ 'trainable': False,
63
+ },
64
+ 'use_lines': False,
65
+ 'use_points': True,
66
+ 'randomize_num_kp': False,
67
+ 'detector': {'name': None},
68
+ 'descriptor': {'name': None},
69
+ 'matcher': {'name': 'nearest_neighbor_matcher'},
70
+ 'filter': {'name': None},
71
+ 'solver': {'name': None},
72
+ 'ground_truth': {
73
+ 'from_pose_depth': False,
74
+ 'from_homography': False,
75
+ 'th_positive': 3,
76
+ 'th_negative': 5,
77
+ 'reward_positive': 1,
78
+ 'reward_negative': -0.25,
79
+ 'is_likelihood_soft': True,
80
+ 'p_random_occluders': 0,
81
+ 'n_line_sampled_pts': 50,
82
+ 'line_perp_dist_th': 5,
83
+ 'overlap_th': 0.2,
84
+ 'min_visibility_th': 0.5
85
+ },
86
+ }
87
+ required_data_keys = ['image0', 'image1']
88
+ strict_conf = False # need to pass new confs to children models
89
+ components = [
90
+ 'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']
91
+
92
+ def _init(self, conf):
93
+ if conf.extractor.name:
94
+ self.extractor = SPWireframeDescriptor(conf.extractor)
95
+
96
+ if conf.matcher.name:
97
+ self.matcher = get_model(conf.matcher.name)(conf.matcher)
98
+ else:
99
+ self.required_data_keys += ['matches0']
100
+
101
+ if conf.filter.name:
102
+ self.filter = get_model(conf.filter.name)(conf.filter)
103
+
104
+ if conf.solver.name:
105
+ self.solver = get_model(conf.solver.name)(conf.solver)
106
+
107
+ def _forward(self, data):
108
+
109
+ def process_siamese(data, i):
110
+ data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
111
+ if self.conf.extractor.name:
112
+ pred_i = self.extractor(data_i)
113
+ else:
114
+ pred_i = {}
115
+ if self.conf.detector.name:
116
+ pred_i = self.detector(data_i)
117
+ else:
118
+ for k in ['keypoints', 'keypoint_scores', 'descriptors',
119
+ 'lines', 'line_scores', 'line_descriptors',
120
+ 'valid_lines']:
121
+ if k in data_i:
122
+ pred_i[k] = data_i[k]
123
+ if self.conf.descriptor.name:
124
+ pred_i = {
125
+ **pred_i, **self.descriptor({**data_i, **pred_i})}
126
+ return pred_i
127
+
128
+ pred0 = process_siamese(data, '0')
129
+ pred1 = process_siamese(data, '1')
130
+
131
+ pred = {**{k + '0': v for k, v in pred0.items()},
132
+ **{k + '1': v for k, v in pred1.items()}}
133
+
134
+ if self.conf.matcher.name:
135
+ pred = {**pred, **self.matcher({**data, **pred})}
136
+
137
+ if self.conf.filter.name:
138
+ pred = {**pred, **self.filter({**data, **pred})}
139
+
140
+ if self.conf.solver.name:
141
+ pred = {**pred, **self.solver({**data, **pred})}
142
+
143
+ return pred
144
+
145
+ def loss(self, pred, data):
146
+ losses = {}
147
+ total = 0
148
+ for k in self.components:
149
+ if self.conf[k].name:
150
+ try:
151
+ losses_ = getattr(self, k).loss(pred, {**pred, **data})
152
+ except NotImplementedError:
153
+ continue
154
+ losses = {**losses, **losses_}
155
+ total = losses_['total'] + total
156
+ return {**losses, 'total': total}
157
+
158
+ def metrics(self, pred, data):
159
+ metrics = {}
160
+ for k in self.components:
161
+ if self.conf[k].name:
162
+ try:
163
+ metrics_ = getattr(self, k).metrics(pred, {**pred, **data})
164
+ except NotImplementedError:
165
+ continue
166
+ metrics = {**metrics, **metrics_}
167
+ return metrics
line_matching/wireframe.py ADDED
@@ -0,0 +1,341 @@
1
+ import numpy as np
2
+ import torch
3
+ from pytlsd import lsd
4
+ from sklearn.cluster import DBSCAN
5
+ import sys
6
+
7
+ from gluestick.models.base_model import BaseModel
8
+ from gluestick.models.superpoint import SuperPoint, sample_descriptors
9
+ from gluestick.geometry import warp_lines_torch
10
+
11
+ from pathlib import Path
12
+ import copy, cv2
13
+ import os, glob
14
+ import scalelsd
15
+ from scalelsd.ssl.models.detector import ScaleLSD
16
+ from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model
17
+
18
+
19
+ def lines_to_wireframe(lines, line_scores, all_descs, conf):
20
+ """ Given a set of lines, their score and dense descriptors,
21
+ merge close-by endpoints and compute a wireframe defined by
22
+ its junctions and connectivity.
23
+ Returns:
24
+ junctions: list of [num_junc, 2] tensors listing all wireframe junctions
25
+ junc_scores: list of [num_junc] tensors with the junction score
26
+ junc_descs: list of [dim, num_junc] tensors with the junction descriptors
27
+ connectivity: list of [num_junc, num_junc] bool arrays with True when 2 junctions are connected
28
+ new_lines: the new set of [b_size, num_lines, 2, 2] lines
29
+ lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the junctions of each endpoint
30
+ num_true_junctions: a list of the number of valid junctions for each image in the batch,
31
+ i.e. before filling with random ones
32
+ """
33
+ b_size, _, _, _ = all_descs.shape
34
+ device = lines.device
35
+ endpoints = lines.reshape(b_size, -1, 2)
36
+
37
+ (junctions, junc_scores, junc_descs, connectivity, new_lines,
38
+ lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], []
39
+ for bs in range(b_size):
40
+ # Cluster the junctions that are close-by
41
+ db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(
42
+ endpoints[bs].cpu().numpy())
43
+ clusters = db.labels_
44
+ n_clusters = len(set(clusters))
45
+ num_true_junctions.append(n_clusters)
46
+
47
+ # Compute the average junction and score for each cluster
48
+ clusters = torch.tensor(clusters, dtype=torch.long,
49
+ device=device)
50
+ new_junc = torch.zeros(n_clusters, 2, dtype=torch.float,
51
+ device=device)
52
+ new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2),
53
+ endpoints[bs], reduce='mean',
54
+ include_self=False)
55
+ junctions.append(new_junc)
56
+ new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device)
57
+ new_scores.scatter_reduce_(
58
+ 0, clusters, torch.repeat_interleave(line_scores[bs], 2),
59
+ reduce='mean', include_self=False)
60
+ junc_scores.append(new_scores)
61
+
62
+ # Compute the new lines
63
+ new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2))
64
+ lines_junc_idx.append(clusters.reshape(-1, 2))
65
+
66
+ # Compute the junction connectivity
67
+ junc_connect = torch.eye(n_clusters, dtype=torch.bool,
68
+ device=device)
69
+ pairs = clusters.reshape(-1, 2) # these pairs are connected by a line
70
+ junc_connect[pairs[:, 0], pairs[:, 1]] = True
71
+ junc_connect[pairs[:, 1], pairs[:, 0]] = True
72
+ connectivity.append(junc_connect)
73
+
74
+ # Interpolate the new junction descriptors
75
+ junc_descs.append(sample_descriptors(
76
+ junctions[-1][None], all_descs[bs:(bs + 1)], 8)[0])
77
+
78
+ new_lines = torch.stack(new_lines, dim=0)
79
+ lines_junc_idx = torch.stack(lines_junc_idx, dim=0)
80
+ return (junctions, junc_scores, junc_descs, connectivity,
81
+ new_lines, lines_junc_idx, num_true_junctions)
82
+
83
+
84
+ class SPWireframeDescriptor(BaseModel):
85
+ default_conf = {
86
+ 'sp_params': {
87
+ 'has_detector': True,
88
+ 'has_descriptor': True,
89
+ 'descriptor_dim': 256,
90
+ 'trainable': False,
91
+
92
+ # Inference
93
+ 'return_all': True,
94
+ 'sparse_outputs': True,
95
+ 'nms_radius': 4,
96
+ 'detection_threshold': 0.005,
97
+ 'max_num_keypoints': 1000,
98
+ 'force_num_keypoints': True,
99
+ 'remove_borders': 4,
100
+ },
101
+ 'wireframe_params': {
102
+ 'merge_points': True,
103
+ 'merge_line_endpoints': True,
104
+ 'nms_radius': 3,
105
+ 'max_n_junctions': 500,
106
+ },
107
+ 'max_n_lines': 250,
108
+ 'min_length': 15,
109
+ }
110
+ required_data_keys = ['image']
111
+
112
+ def _init(self, conf):
113
+ self.conf = conf
114
+ self.sp = SuperPoint(conf.sp_params)
115
+ self.extr_conf = {}
116
+
117
+ def detect_lsd_lines(self, x, max_n_lines=None):
118
+ if max_n_lines is None:
119
+ max_n_lines = self.conf.max_n_lines
120
+ lines, scores, valid_lines = [], [], []
121
+ for b in range(len(x)):
122
+ # For each image on batch
123
+ img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8)
124
+ if max_n_lines is None:
125
+ b_segs = lsd(img)
126
+ else:
127
+ for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]:
128
+ b_segs = lsd(img, scale=s)
129
+ if len(b_segs) >= max_n_lines:
130
+ break
131
+
132
+ segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1)
133
+ # Remove short lines
134
+ b_segs = b_segs[segs_length >= self.conf.min_length]
135
+ segs_length = segs_length[segs_length >= self.conf.min_length]
136
+ b_scores = b_segs[:, -1] * np.sqrt(segs_length)
137
+ # Take the most relevant segments with
138
+ indices = np.argsort(-b_scores)
139
+ if max_n_lines is not None:
140
+ indices = indices[:max_n_lines]
141
+ lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2)))
142
+ scores.append(torch.from_numpy(b_scores[indices]))
143
+ valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool))
144
+
145
+ lines = torch.stack(lines).to(x)
146
+ scores = torch.stack(scores).to(x)
147
+ valid_lines = torch.stack(valid_lines).to(x.device)
148
+ return lines, scores, valid_lines
149
+
150
+ def update_conf(self, conf):
151
+ self.extr_conf = conf
152
+
153
+ def _forward(self, data):
154
+ b_size, _, h, w = data['image'].shape
155
+ device = data['image'].device
156
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
157
+
158
+ if not self.conf.sp_params.force_num_keypoints:
159
+ assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"
160
+
161
+ # Line detection
162
+ if 'lines' not in data or 'line_scores' not in data:
163
+ if self.extr_conf is None:
164
+ ckpt = 'models/scalelsd-vitbase-v1-train-sa1b.pt'
165
+ model = load_scalelsd_model(ckpt, device)
166
+ model.junction_threshold_hm = 0.008
167
+ threshold = 5
168
+ model.num_junctions_inference = 4096
169
+ size = 512
170
+ image = data['image']
171
+ image_size = image.shape[-2:]
172
+ image_np = image[0,0].cpu().numpy()
173
+ image_cp = copy.deepcopy(image_np)
174
+ image_torch = torch.from_numpy(cv2.resize(image_cp, (size, size))).float()
175
+ image_cuda = image_torch[None,None].to(device)
176
+ meta = {
177
+ 'width': image_size[1],
178
+ 'height':image_size[0],
179
+ 'filename': '',
180
+ 'use_lsd': False,
181
+ 'use_nms': False,
182
+ }
183
+ outputs, _ = model(image_cuda, meta)
184
+ lines = outputs[0]['lines_pred']
185
+ line_scores = outputs[0]['lines_score']
186
+ lines = lines[line_scores>=threshold]
187
+ line_scores = line_scores[line_scores>=threshold][None]
188
+ elif self.extr_conf['model_name'] != 'lsd':
189
+ # initialize model
190
+ ckpt = "models/" + self.extr_conf['model_name']
191
+ model = load_scalelsd_model(ckpt, device)
192
+ # set model parameters
193
+ model.junction_threshold_hm = self.extr_conf['junction_threshold_hm']
194
+ model.num_junctions_inference = self.extr_conf['num_junctions_inference']
195
+ width, height = self.extr_conf['width'], self.extr_conf['height']
196
+
197
+ image = data['image']
198
+ image_size = image.shape[-2:]
199
+ image_np = image[0,0].cpu().numpy()
200
+ image_cp = copy.deepcopy(image_np)
201
+ image_torch = torch.from_numpy(cv2.resize(image_cp, (width, height))).float()
202
+ image_cuda = image_torch[None,None].to(device)
203
+ meta = {
204
+ 'width': image_size[1],
205
+ 'height':image_size[0],
206
+ 'filename': '',
207
+ 'use_lsd': self.extr_conf['use_lsd'],
208
+ 'use_nms': self.extr_conf['use_nms'],
209
+ }
210
+ outputs, _ = model(image_cuda, meta)
211
+ lines = outputs[0]['lines_pred']
212
+ line_scores = outputs[0]['lines_score']
213
+ lines = lines[line_scores>=self.extr_conf['threshold']]
214
+ line_scores = line_scores[line_scores>=self.extr_conf['threshold']][None]
215
+ else:
216
+ if 'original_img' in data:
217
+ # Detect more lines, because when projecting them to the image most of them will be discarded
218
+ lines, line_scores, valid_lines = self.detect_lsd_lines(
219
+ data['original_img'], self.conf.max_n_lines * 3)
220
+ # Apply the same transformation that is applied in homography_adaptation
221
+ lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:])
222
+ valid_lines = valid_lines & valid_lines2
223
+ lines[~valid_lines] = -1
224
+ line_scores[~valid_lines] = 0
225
+ # Re-sort the line segments to pick the ones that are inside the image and have bigger score
226
+ sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True)
227
+ line_scores = sorted_scores[:, :self.conf.max_n_lines]
228
+ sorting_indices = sorting_indices[:, :self.conf.max_n_lines]
229
+ lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1)
230
+ valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1)
231
+ else:
232
+ lines, line_scores, valid_lines = self.detect_lsd_lines(data['image'],max_n_lines=1000000)
233
+
234
+ else:
235
+ lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines']
236
+ if line_scores.shape[-1] != 0:
237
+ line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None])
238
+
239
+ # SuperPoint prediction
240
+ pred = self.sp(data)
241
+
242
+ # Remove keypoints that are too close to line endpoints
243
+ if self.conf.wireframe_params.merge_points:
244
+ kp = pred['keypoints']
245
+ line_endpts = lines.reshape(b_size, -1, 2)
246
+ dist_pt_lines = torch.norm(
247
+ kp[:, :, None] - line_endpts[:, None], dim=-1)
248
+ # For each keypoint, mark it as valid or to remove
249
+ pts_to_remove = torch.any(
250
+ dist_pt_lines < self.conf.sp_params.nms_radius, dim=2)
251
+ # Simply remove them (we assume batch_size = 1 here)
252
+ assert len(kp) == 1
253
+ pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None]
254
+ pred['keypoint_scores'] = pred['keypoint_scores'][0][~pts_to_remove[0]][None]
255
+ pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None]
256
+
257
+ # Connect the lines together to form a wireframe
258
+ orig_lines = lines.clone()
259
+ if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0:
260
+ # Merge first close-by endpoints to connect lines
261
+ (line_points, line_pts_scores, line_descs, line_association,
262
+ lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe(
263
+ lines, line_scores, pred['all_descriptors'],
264
+ conf=self.conf.wireframe_params)
265
+
266
+ # Add the keypoints to the junctions and fill the rest with random keypoints
267
+ (all_points, all_scores, all_descs,
268
+ pl_associativity) = [], [], [], []
269
+ for bs in range(b_size):
270
+ all_points.append(torch.cat(
271
+ [line_points[bs], pred['keypoints'][bs]], dim=0))
272
+ all_scores.append(torch.cat(
273
+ [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0))
274
+ all_descs.append(torch.cat(
275
+ [line_descs[bs], pred['descriptors'][bs]], dim=1))
276
+
277
+ associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device)
278
+ associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \
279
+ line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]]
280
+ pl_associativity.append(associativity)
281
+
282
+ all_points = torch.stack(all_points, dim=0)
283
+ all_scores = torch.stack(all_scores, dim=0)
284
+ all_descs = torch.stack(all_descs, dim=0)
285
+ pl_associativity = torch.stack(pl_associativity, dim=0)
286
+ else:
287
+ # Lines are independent
288
+ all_points = torch.cat([lines.reshape(b_size, -1, 2),
289
+ pred['keypoints']], dim=1)
290
+ n_pts = all_points.shape[1]
291
+ num_lines = lines.shape[1]
292
+ num_true_junctions = [num_lines * 2] * b_size
293
+ all_scores = torch.cat([
294
+ torch.repeat_interleave(line_scores, 2, dim=1),
295
+ pred['keypoint_scores']], dim=1)
296
+ pred['line_descriptors'] = self.endpoints_pooling(
297
+ lines, pred['all_descriptors'], (h, w))
298
+ all_descs = torch.cat([
299
+ pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1),
300
+ pred['descriptors']], dim=2)
301
+ pl_associativity = torch.eye(
302
+ n_pts, dtype=torch.bool,
303
+ device=device)[None].repeat(b_size, 1, 1)
304
+ lines_junc_idx = torch.arange(
305
+ num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1)
306
+
307
+ del pred['all_descriptors'] # Remove dense descriptors to save memory
308
+ torch.cuda.empty_cache()
309
+
310
+ return {'keypoints': all_points,
311
+ 'keypoint_scores': all_scores,
312
+ 'descriptors': all_descs,
313
+ 'pl_associativity': pl_associativity,
314
+ 'num_junctions': torch.tensor(num_true_junctions),
315
+ 'lines': lines,
316
+ 'orig_lines': orig_lines,
317
+ 'lines_junc_idx': lines_junc_idx,
318
+ 'line_scores': line_scores,
319
+ # 'valid_lines': valid_lines,
320
+ }
321
+
322
+ @staticmethod
323
+ def endpoints_pooling(segs, all_descriptors, img_shape):
324
+ assert segs.ndim == 4 and segs.shape[-2:] == (2, 2)
325
+ filter_shape = all_descriptors.shape[-2:]
326
+ scale_x = filter_shape[1] / img_shape[1]
327
+ scale_y = filter_shape[0] / img_shape[0]
328
+
329
+ scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long()
330
+ scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1)
331
+ scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1)
332
+ line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])]
333
+ for b, b_segs in enumerate(scaled_segs)]
334
+ line_descriptors = torch.cat(line_descriptors)
335
+ return line_descriptors # Shape (1, 256, 308, 2)
336
+
337
+ def loss(self, pred, data):
338
+ raise NotImplementedError
339
+
340
+ def metrics(self, pred, data):
341
+ return {}
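For reference, the LSD fallback above keeps only segments longer than min_length and ranks the rest by confidence times the square root of their length before truncating to max_n_lines. The standalone sketch below is only an illustration of that filtering (not part of the commit): it relies on pytlsd from requirements.txt and, like the code above, treats the last column returned by lsd() as the per-segment confidence.

```python
import numpy as np
from pytlsd import lsd

def top_lsd_lines(gray_u8, max_n_lines=250, min_length=15):
    # lsd() returns one row per segment: x1, y1, x2, y2, ..., confidence (last column)
    segs = lsd(gray_u8)
    length = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
    segs, length = segs[length >= min_length], length[length >= min_length]
    scores = segs[:, -1] * np.sqrt(length)          # favour long, confident segments
    order = np.argsort(-scores)[:max_n_lines]
    return segs[order, :4].reshape(-1, 2, 2), scores[order]
```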
predictor/predict.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ import os
5
+ import os.path as osp
6
+ import glob
7
+ from tqdm import tqdm
8
+
9
+ from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph
10
+
11
+ from scalelsd.ssl.datasets import dataset_util
12
+ from scalelsd.ssl.models.detector import ScaleLSD
13
+ from scalelsd.ssl.misc.train_utils import load_scalelsd_model
14
+
15
+ from torch.utils.data import DataLoader
16
+ import torch.utils.data.dataloader as torch_loader
17
+
18
+ from pathlib import Path
19
+ import argparse, yaml, logging, time, datetime, cv2, copy, sys, json
20
+ from easydict import EasyDict
21
+ import accelerate
22
+ from accelerate import load_checkpoint_and_dispatch
23
+ import matplotlib
24
+ import matplotlib.pyplot as plt
25
+
26
+ def parse_args():
27
+ aparser = argparse.ArgumentParser()
28
+ aparser.add_argument('-c', '--ckpt', default='models/scalelsd-vitbase-v1-train-sa1b.pt', type=str, help='the path for loading checkpoints')
29
+ aparser.add_argument('-t','--threshold', default=10,type=float)
30
+ aparser.add_argument('-i', '--img', required=True, type=str)
31
+ aparser.add_argument('--width', default=512, type=int)
32
+ aparser.add_argument('--height', default=512,type=int)
33
+ aparser.add_argument('--whitebg', default=0.0, type=float)
34
+ aparser.add_argument('--saveto', default=None, type=str,)
35
+ aparser.add_argument('-e','--ext', default='pdf', type=str, choices=['pdf','png','json'])
36
+ aparser.add_argument('--device', default='cuda', type=str, choices=['cuda','cpu','mps'])
37
+ aparser.add_argument('--disable-show', default=False, action='store_true')
38
+ aparser.add_argument('--draw-junctions-only', default=False, action='store_true')
39
+ aparser.add_argument('--use_lsd', default=False, action='store_true')
40
+ aparser.add_argument('--use_nms', default=False, action='store_true')
41
+
42
+ ScaleLSD.cli(aparser)
43
+
44
+ args = aparser.parse_args()
45
+
46
+ ScaleLSD.configure(args)
47
+
48
+ return args
49
+
50
+
51
+ def main():
52
+ args = parse_args()
53
+
54
+ model = load_scalelsd_model(args.ckpt, device=args.device)
55
+
56
+ # Set up output directory and painter
57
+ if args.saveto is None:
58
+ print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
59
+ args.saveto = 'temp_output/ScaleLSD'
60
+ os.makedirs(args.saveto,exist_ok=True)
61
+
62
+ show.painters.HAWPainter.confidence_threshold = args.threshold
63
+ # show.painters.HAWPainter.line_width = 2
64
+ # show.painters.HAWPainter.marker_size = 4
65
+ show.Canvas.show = not args.disable_show
66
+ if args.whitebg > 0.0:
67
+ show.Canvas.white_overlay = args.whitebg
68
+ painter = show.painters.HAWPainter()
69
+ edge_color = 'orange' # 'midnightblue'
70
+ vertex_color = 'Cyan' # 'deeppink'
71
+
72
+ # Prepare images
73
+ all_images = []
74
+ if os.path.isfile(args.img) and args.img.endswith(('.jpg', '.png')):
75
+ all_images.append(args.img)
76
+ elif os.path.isdir(args.img):
77
+ for file in os.listdir(args.img):
78
+ if file.endswith(('.jpg', '.png')):
79
+ fname = os.path.join(args.img, file)
80
+ all_images.append(fname)
81
+ all_images = sorted(all_images)
82
+ else:
83
+ raise ValueError('Input must be a file or a directory containing images.')
84
+
85
+ # Inference
86
+ for fname in tqdm(all_images):
87
+ pname = Path(fname)
88
+ image = cv2.imread(fname,0)
89
+
90
+ # resize the input; the default inference size is [512, 512]
91
+ ori_shape = image.shape[:2]
92
+ image_cp = copy.deepcopy(image)
93
+ image_ = cv2.resize(image_cp, (args.width, args.height))
94
+ image_ = torch.from_numpy(image_).float()/255.0
95
+ image_ = image_[None,None].to(args.device)
96
+
97
+ meta = {
98
+ 'width': ori_shape[1],
99
+ 'height':ori_shape[0],
100
+ 'filename': '',
101
+ 'use_lsd': args.use_lsd,
102
+ 'use_nms': args.use_nms,
103
+ }
104
+
105
+ with torch.no_grad():
106
+ outputs, _ = model(image_, meta)
107
+ outputs = outputs[0]
108
+
109
+
110
+ if args.saveto is not None:
111
+
112
+ if args.ext in ['png', 'pdf']:
113
+ fig_file = osp.join(args.saveto, pname.with_suffix('.'+args.ext).name)
114
+ with show.image_canvas(fname, fig_file=fig_file) as ax:
115
+ if args.draw_junctions_only:
116
+ painter.draw_junctions(ax,outputs)
117
+ else:
118
+ # painter.draw_wireframe(ax,outputs)
119
+ painter.draw_wireframe(ax,outputs, edge_color=edge_color, vertex_color=vertex_color)
120
+ elif args.ext == 'json':
121
+ indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'],outputs['lines_pred'])
122
+ wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices, outputs['lines_score'], outputs['width'], outputs['height'])
123
+ outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
124
+ with open(outpath,'w') as f:
125
+ json.dump(wireframe.jsonize(),f)
126
+ else:
127
+ raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
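The same pipeline can be driven programmatically. The sketch below condenses the loop above into a single image (assumptions: the default checkpoint is present under models/ and example.png is a placeholder path for any readable image).

```python
import cv2, json, torch
from scalelsd.base import WireframeGraph
from scalelsd.ssl.misc.train_utils import load_scalelsd_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_scalelsd_model('models/scalelsd-vitbase-v1-train-sa1b.pt', device=device)

image = cv2.imread('example.png', 0)                         # grayscale, placeholder path
h, w = image.shape[:2]
x = torch.from_numpy(cv2.resize(image, (512, 512))).float() / 255.0
meta = {'width': w, 'height': h, 'filename': '', 'use_lsd': False, 'use_nms': False}

with torch.no_grad():
    outputs, _ = model(x[None, None].to(device), meta)
out = outputs[0]

# serialize the detection exactly as the json branch above does
indices = WireframeGraph.xyxy2indices(out['juncs_pred'], out['lines_pred'])
graph = WireframeGraph(out['juncs_pred'], out['juncs_score'], indices,
                       out['lines_score'], out['width'], out['height'])
with open('example.json', 'w') as f:
    json.dump(graph.jsonize(), f)
```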
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+
2
+ opencv-python
3
+ cython
4
+ matplotlib
5
+ yacs
6
+ scikit-image
7
+ tqdm
8
+ python-json-logger
9
+ h5py
10
+ shapely
11
+ pycolmap
12
+ seaborn
13
+ kornia
14
+ easydict
15
+ pynvml
16
+ timm
17
+ einops==0.7.0
18
+ numpy==1.26.4
19
+ gradio
20
+ pydantic==2.10.6
21
+ pytlsd@git+https://github.com/iago-suarez/pytlsd.git@4180ab8
scalelsd/.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ __pycache__/
2
+ */__pycache__/
3
+ **/__pycache__/
4
+
5
+ data-ssl
6
+ exp
7
+ exp-ssl
8
+ temp_output
9
+ third_party
10
+ ./models
scalelsd/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import base
2
+ from . import ssl
scalelsd/base/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .csrc import _C
2
+ from . import utils
3
+ from .utils.logger import setup_logger
4
+ from .utils.metric_logger import MetricLogger
5
+ from .wireframe import WireframeGraph
6
+
7
+ __all__ = [
8
+ "_C",
9
+ "utils",
10
+ "setup_logger",
11
+ "MetricLogger",
12
+ "WireframeGraph",
13
+ ]
scalelsd/base/csrc/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from torch.utils.cpp_extension import load
2
+ import glob
3
+ import os.path as osp
4
+
5
+ __this__ = osp.dirname(__file__)
6
+
7
+ try:
8
+ _C = load(name='_C',sources=[
9
+ osp.join(__this__,'binding.cpp'),
10
+ osp.join(__this__,'linesegment.cu'),
11
+ ]
12
+ )
13
+ except Exception:
14
+ # JIT compilation of the CUDA extension failed (e.g., no CUDA toolchain available)
15
+ _C = None
16
+
17
+ __all__ = ["_C"]
18
+
19
+ #_C = load(name='base._C', sources=['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])
scalelsd/base/csrc/binding.cpp ADDED
@@ -0,0 +1,5 @@
1
+ #include "linesegment.h"
2
+
3
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4
+ m.def("encodels", &encodels, "Encoding line segments to maps");
5
+ }
scalelsd/base/csrc/linesegment.cu ADDED
@@ -0,0 +1,139 @@
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+
4
+ // #include <THC/THC.h>
5
+ // #include <THC/THCDeviceUtils.cuh>
6
+ #include <torch/torch.h>
7
+ #include <torch/extension.h>
8
+
9
+ #include <vector>
10
+ #include <iostream>
11
+
12
+ int const CUDA_NUM_THREADS = 1024;
13
+
14
+ inline int CUDA_GET_BLOCKS(const int N) {
15
+ return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
16
+ }
17
+
18
+ #define CUDA_1D_KERNEL_LOOP(i, n) \
19
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
20
+ i += blockDim.x * gridDim.x)
21
+
22
+
23
+ __global__ void encode_kernel(const int nthreads, const float* lines,
24
+ const int input_height, const int input_width, const int num,
25
+ const int height, const int width, float* map,
26
+ bool* label, float* tmap)
27
+ {
28
+ CUDA_1D_KERNEL_LOOP(index, nthreads){
29
+ int w = index % width;
30
+ int h = (index / width) % height;
31
+ int x_index = h*width + w;
32
+ int y_index = height*width + h*width + w;
33
+ int ux_index = 2*height*width + h*width + w;
34
+ int uy_index = 3*height*width + h*width + w;
35
+ int vx_index = 4*height*width + h*width + w;
36
+ int vy_index = 5*height*width + h*width + w;
37
+ int label_index = h*width + w;
38
+
39
+ float px = (float) w;
40
+ float py = (float) h;
41
+ float min_dis = 1e30;
42
+ int minp = -1;
43
+ bool flagp = true;
44
+ for(int i = 0; i < num; ++i) {
45
+ float xs = (float)width /(float)input_width;
46
+ float ys = (float)height /(float)input_height;
47
+ float x1 = lines[4*i ]*xs;
48
+ float y1 = lines[4*i+1]*ys;
49
+ float x2 = lines[4*i+2]*xs;
50
+ float y2 = lines[4*i+3]*ys;
51
+
52
+ float dx = x2 - x1;
53
+ float dy = y2 - y1;
54
+ float ux = x1 - px;
55
+ float uy = y1 - py;
56
+ float vx = x2 - px;
57
+ float vy = y2 - py;
58
+ float norm2 = dx*dx + dy*dy;
59
+ bool flag = false;
60
+ float t = ((px-x1)*dx + (py-y1)*dy)/(norm2+1e-6);
61
+ if (t<=1 && t>=0.0)
62
+ flag = true;
63
+
64
+ t = t<0.0? 0.0:t;
65
+ t = t>1.0? 1.0:t;
66
+
67
+ float ax = x1 + t*(x2-x1) - px;
68
+ float ay = y1 + t*(y2-y1) - py;
69
+
70
+ float dis = ax*ax + ay*ay;
71
+ if (dis < min_dis) {
72
+ min_dis = dis;
73
+ map[x_index] = ax;
74
+ map[y_index] = ay;
75
+ float norm_u2 = ux*ux+uy*uy;
76
+ float norm_v2 = vx*vx+vy*vy;
77
+
78
+ if (norm_u2 < norm_v2){
79
+ map[ux_index] = ux;
80
+ map[uy_index] = uy;
81
+ map[vx_index] = vx;
82
+ map[vy_index] = vy;
83
+ }
84
+ else{
85
+ map[ux_index] = vx;
86
+ map[uy_index] = vy;
87
+ map[vx_index] = ux;
88
+ map[vy_index] = uy;
89
+ }
90
+
91
+ minp = i;
92
+ if (flag)
93
+ flagp = true;
94
+ else
95
+ flagp = false;
96
+
97
+ tmap[index] = t;
98
+ }
99
+ }
100
+ // label[label_index+minp*height*width] = flagp;
101
+
102
+ }
103
+ }
104
+
105
+
106
+ std::tuple<at::Tensor, at::Tensor, at::Tensor> lsencode_cuda(
107
+ const at::Tensor& lines,
108
+ const int input_height,
109
+ const int input_width,
110
+ const int height,
111
+ const int width,
112
+ const int num_lines)
113
+
114
+ {
115
+ auto map = at::zeros({6,height,width}, lines.options());
116
+ auto tmap = at::zeros({1,height,width}, lines.options());
117
+ auto label = at::zeros({1,height,width}, lines.options().dtype(at::kBool));
118
+ auto nthreads = height*width;
119
+
120
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
121
+
122
+ float* map_data = map.data<float>();
123
+ float* tmap_data = tmap.data<float>();
124
+ bool* label_data = label.data<bool>();
125
+
126
+ encode_kernel<<<CUDA_GET_BLOCKS(nthreads), CUDA_NUM_THREADS >>>(
127
+ nthreads,
128
+ lines.contiguous().data<float>(),
129
+ input_height, input_width,
130
+ num_lines,
131
+ height, width,
132
+ map_data,
133
+ label_data,
134
+ tmap_data);
135
+
136
+ // THCudaCheck(cudaGetLastError());
137
+
138
+ return std::make_tuple(map, label, tmap);
139
+ }
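As a readability aid, the kernel above can be summarized per pixel as: scan all segments, keep the nearest one, write the offset to its closest point into channels 0-1 and the offsets to its nearer/farther endpoints into channels 2-5 (with the projection parameter t going into tmap). The NumPy sketch below is only a reference restatement of that logic, assuming the line coordinates are already given at the output resolution so the xs/ys rescaling in the kernel is a no-op.

```python
import numpy as np

def encode_lines_reference(lines, height, width):
    # lines: iterable of (x1, y1, x2, y2) in output-resolution coordinates
    ys, xs = np.mgrid[0:height, 0:width].astype(np.float32)
    out = np.zeros((6, height, width), np.float32)
    best = np.full((height, width), np.inf, np.float32)
    for x1, y1, x2, y2 in lines:
        dx, dy = x2 - x1, y2 - y1
        t = np.clip(((xs - x1) * dx + (ys - y1) * dy) / (dx * dx + dy * dy + 1e-6), 0.0, 1.0)
        ax, ay = x1 + t * dx - xs, y1 + t * dy - ys          # offset to the closest point
        dist = ax * ax + ay * ay
        upd = dist < best
        best[upd] = dist[upd]
        ux, uy, vx, vy = x1 - xs, y1 - ys, x2 - xs, y2 - ys  # offsets to both endpoints
        swap = (ux * ux + uy * uy) > (vx * vx + vy * vy)     # nearer endpoint goes first
        near_x, near_y = np.where(swap, vx, ux), np.where(swap, vy, uy)
        far_x, far_y = np.where(swap, ux, vx), np.where(swap, uy, vy)
        for c, m in enumerate((ax, ay, near_x, near_y, far_x, far_y)):
            out[c][upd] = m[upd]
    return out
```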
scalelsd/base/csrc/linesegment.h ADDED
@@ -0,0 +1,26 @@
1
+ // #pragma once
2
+ #include <torch/extension.h>
3
+
4
+ std::tuple<at::Tensor, at::Tensor, at::Tensor> lsencode_cuda(
5
+ const at::Tensor& lines,
6
+ const int input_height,
7
+ const int input_width,
8
+ const int height,
9
+ const int width,
10
+ const int num_lines);
11
+
12
+ std::tuple<at::Tensor,at::Tensor,at::Tensor> encodels(
13
+ const at::Tensor& lines,
14
+ const int input_height,
15
+ const int input_width,
16
+ const int height,
17
+ const int width,
18
+ const int num_lines)
19
+ {
20
+ return lsencode_cuda(lines,
21
+ input_height,
22
+ input_width,
23
+ height,
24
+ width,
25
+ num_lines);
26
+ }
scalelsd/base/show/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .canvas import Canvas, image_canvas, canvas
2
+ from .painters import HAWPainter
3
+ from .cli import cli, configure
scalelsd/base/show/canvas.py ADDED
@@ -0,0 +1,153 @@
1
+ from contextlib import contextmanager
2
+ import logging
3
+ import os
4
+
5
+
6
+ import numpy as np
7
+ import cv2
8
+
9
+ try:
10
+ import matplotlib.pyplot as plt # pylint: disable=import-error
11
+
12
+ except ModuleNotFoundError as err:
13
+ if err.name != 'matplotlib':
14
+ raise err
15
+ plt = None
16
+
17
+
18
+ LOG = logging.getLogger(__name__)
19
+
20
+ class Canvas:
21
+ """Canvas for plotting.
22
+ All methods expose Axes objects. To get Figure objects, you can ask the axis
23
+ `ax.get_figure()`.
24
+ """
25
+
26
+ all_images_directory = None
27
+ all_images_count = 0
28
+ show = False
29
+ image_width = 7.0
30
+ image_height = None
31
+ blank_dpi = 200
32
+ image_dpi_factor = 1.0
33
+ image_min_dpi = 50.0
34
+ out_file_extension = 'pdf'
35
+ white_overlay = False
36
+
37
+ @classmethod
38
+ def generic_name(cls):
39
+ if cls.all_images_directory is None:
40
+ return None
41
+ os.makedirs(cls.all_images_directory, exist_ok=True)
42
+
43
+ cls.all_images_count += 1
44
+ return os.path.join(cls.all_images_directory,
45
+ '{:04}.{}'.format(cls.all_images_count, cls.out_file_extension))
46
+
47
+ @classmethod
48
+ @contextmanager
49
+ def blank(cls, fig_file=None, *, dpi=None, nomargin=False, **kwargs):
50
+ if plt is None:
51
+ raise Exception('please install matplotlib')
52
+ if fig_file is None:
53
+ fig_file = cls.generic_name()
54
+
55
+ if dpi is None:
56
+ dpi = cls.blank_dpi
57
+
58
+ if 'figsize' not in kwargs:
59
+ kwargs['figsize'] = (10, 6)
60
+
61
+ if nomargin:
62
+ if 'gridspec_kw' not in kwargs:
63
+ kwargs['gridspec_kw'] = {}
64
+ kwargs['gridspec_kw']['wspace'] = 0
65
+ kwargs['gridspec_kw']['hspace'] = 0
66
+ kwargs['gridspec_kw']['left'] = 0.0
67
+ kwargs['gridspec_kw']['right'] = 1.0
68
+ kwargs['gridspec_kw']['top'] = 1.0
69
+ kwargs['gridspec_kw']['bottom'] = 0.0
70
+
71
+ fig, ax = plt.subplots(dpi=dpi, **kwargs)
72
+
73
+ yield ax
74
+
75
+ fig.set_tight_layout(not nomargin)
76
+ if fig_file:
77
+ LOG.debug('writing image to %s', fig_file)
78
+ fig.savefig(fig_file)
79
+
80
+ if cls.show:
81
+ plt.show()
82
+ plt.close(fig)
83
+
84
+
85
+ @classmethod
86
+ @contextmanager
87
+ def image(cls, image, fig_file=None, *, margin=None, **kwargs):
88
+ if plt is None:
89
+ raise Exception('please install matplotlib')
90
+ if fig_file is None:
91
+ fig_file = cls.generic_name()
92
+
93
+ if isinstance(image, str):
94
+ image = cv2.imread(image)[...,::-1]
95
+ else:
96
+ image = np.asarray(image)
97
+
98
+ if margin is None:
99
+ margin = [0.0, 0.0, 0.0, 0.0]
100
+ elif isinstance(margin, float):
101
+ margin = [margin, margin, margin, margin]
102
+ assert len(margin) == 4
103
+
104
+ if 'figsize' not in kwargs:
105
+ # compute figure size: use image ratio and take the drawable area
106
+ # into account that is left after subtracting margins.
107
+ image_ratio = image.shape[0] / image.shape[1]
108
+ image_area_ratio = (1.0 - margin[1] - margin[3]) / (1.0 - margin[0] - margin[2])
109
+ if cls.image_width is not None:
110
+ kwargs['figsize'] = (
111
+ cls.image_width,
112
+ cls.image_width * image_ratio / image_area_ratio
113
+ )
114
+ elif cls.image_height:
115
+ kwargs['figsize'] = (
116
+ cls.image_height * image_area_ratio / image_ratio,
117
+ cls.image_height
118
+ )
119
+
120
+ # dpi = max(cls.image_min_dpi, image.shape[1] / kwargs['figsize'][0] * cls.image_dpi_factor)
121
+ dpi = 200
122
+ # import pdb; pdb.set_trace()
123
+ fig = plt.figure(dpi=dpi, **kwargs)
124
+ ax = plt.Axes(fig, [0.0 + margin[0],
125
+ 0.0 + margin[1],
126
+ 1.0 - margin[2],
127
+ 1.0 - margin[3]])
128
+
129
+ ax.set_axis_off()
130
+ ax.set_xlim(-0.5, image.shape[1] - 0.5) # imshow uses center-pixel-coordinates
131
+ ax.set_ylim(image.shape[0] - 0.5, -0.5)
132
+ fig.add_axes(ax)
133
+ ax.imshow(image)
134
+ if cls.white_overlay:
135
+ white_screen(ax, cls.white_overlay)
136
+ yield ax
137
+
138
+ if fig_file:
139
+ LOG.debug('writing image to %s', fig_file)
140
+ fig.savefig(fig_file)
141
+ if cls.show:
142
+ plt.show()
143
+
144
+ plt.close(fig)
145
+
146
+ def white_screen(ax, alpha=0.9):
147
+ ax.add_patch(
148
+ plt.Rectangle((0, 0), 1, 1, transform=ax.transAxes, alpha=alpha,
149
+ facecolor='white')
150
+ )
151
+
152
+ canvas = Canvas.blank
153
+ image_canvas = Canvas.image
scalelsd/base/show/cli.py ADDED
@@ -0,0 +1,24 @@
1
+ # from hawp.config import defaults
2
+ import logging
3
+
4
+ from .canvas import Canvas
5
+ from .painters import HAWPainter
6
+ import matplotlib
7
+ LOG = logging.getLogger(__name__)
8
+
9
+ def cli(parser):
10
+ group = parser.add_argument_group('show')
11
+
12
+ assert not Canvas.show
13
+ group.add_argument('--show', default=False,action='store_true',
14
+ help='show every plot, i.e., call matplotlib show()')
15
+
16
+ group.add_argument('--edge-threshold', default=None, type=float,
17
+ help='show the wireframe edges whose confidences are greater than [edge_threshold]')
18
+ group.add_argument('--out-ext', default='png', type=str,
19
+ help='save the plot in specific format')
20
+ def configure(args):
21
+ Canvas.show = args.show
22
+ Canvas.out_file_extension = args.out_ext
23
+ if args.edge_threshold is not None:
24
+ HAWPainter.confidence_threshold = args.edge_threshold
scalelsd/base/show/painters.py ADDED
@@ -0,0 +1,80 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ try:
8
+ import matplotlib
9
+ import matplotlib.animation
10
+ import matplotlib.collections
11
+ import matplotlib.patches
12
+ except ImportError:
13
+ matplotlib = None
14
+
15
+
16
+ LOG = logging.getLogger(__name__)
17
+
18
+
19
+ class HAWPainter:
20
+ # line_width = None
21
+ # marker_size = None
22
+ line_width = 2
23
+ marker_size = 4
24
+
25
+ confidence_threshold = 0.05
26
+
27
+ def __init__(self):
28
+
29
+ if self.line_width is None:
30
+ self.line_width = 1
31
+
32
+ if self.marker_size is None:
33
+ self.marker_size = max(1, int(self.line_width * 0.5))
34
+
35
+ def draw_junctions(self, ax, wireframe, *,
36
+ edge_color = None, vertex_color = None):
37
+ if wireframe is None:
38
+ return
39
+
40
+ if edge_color is None:
41
+ edge_color = 'b'
42
+ if vertex_color is None:
43
+ vertex_color = 'c'
44
+
45
+ if 'lines_score' in wireframe.keys():
46
+ line_segments = wireframe['lines_pred'][wireframe['lines_score']>self.confidence_threshold]
47
+ else:
48
+ line_segments = wireframe['lines_pred']
49
+
50
+ if isinstance(line_segments, torch.Tensor):
51
+ line_segments = line_segments.cpu().numpy()
52
+
53
+ ax.plot(line_segments[:,0],line_segments[:,1],'.',color=vertex_color)
54
+ ax.plot(line_segments[:,2],line_segments[:,3],'.',
55
+ color=vertex_color)
56
+ def draw_wireframe(self, ax, wireframe, *,
57
+ edge_color = None, vertex_color = None):
58
+ if wireframe is None:
59
+ return
60
+
61
+ if edge_color is None:
62
+ edge_color = 'b'
63
+ if vertex_color is None:
64
+ vertex_color = 'c'
65
+
66
+ if 'lines_score' in wireframe.keys():
67
+ line_segments = wireframe['lines_pred'][wireframe['lines_score']>self.confidence_threshold]
68
+ else:
69
+ line_segments = wireframe['lines_pred']
70
+
71
+ # import pdb;pdb.set_trace()
72
+ if isinstance(line_segments, torch.Tensor):
73
+ line_segments = line_segments.cpu().numpy()
74
+
75
+ # import pdb;pdb.set_trace()
76
+ # line_segments = wireframe.line_segments(threshold=self.confidence_threshold)
77
+ # line_segments = line_segments.cpu().numpy()
78
+ ax.plot([line_segments[:,0],line_segments[:,2]],[line_segments[:,1],line_segments[:,3]],'-',color=edge_color,linewidth=self.line_width)
79
+ ax.plot(line_segments[:,0],line_segments[:,1],'.',color=vertex_color,markersize=self.marker_size)
80
+ ax.plot(line_segments[:,2],line_segments[:,3],'.',color=vertex_color,markersize=self.marker_size)
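A short usage sketch tying the canvas and painter together (example.json / example.png are placeholder paths; the JSON is assumed to be one written by predictor/predict.py above):

```python
from scalelsd.base import show, WireframeGraph

graph = WireframeGraph.load_json('example.json')            # placeholder path
lines = graph.line_segments(threshold=0.05, to_np=True)     # rows: x1, y1, x2, y2, score

painter = show.painters.HAWPainter()
with show.image_canvas('example.png', fig_file='wireframe.pdf') as ax:
    painter.draw_wireframe(ax, {'lines_pred': lines[:, :4], 'lines_score': lines[:, 4]})
```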
scalelsd/base/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+
scalelsd/base/utils/logger.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ import logging
3
+ import os
4
+ import sys
5
+ from pythonjsonlogger import jsonlogger
6
+
7
+
8
+ def setup_logger(name, save_dir, out_file='log.txt', json_format=False, rank=0):
9
+ logger = logging.getLogger(name)
10
+ logger.setLevel(logging.DEBUG)
11
+
12
+ if json_format:
13
+ formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
14
+ else:
15
+ formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
16
+
17
+ if rank == 0:
18
+ ch = logging.StreamHandler(stream=sys.stdout)
19
+ ch.setLevel(logging.DEBUG)
20
+ ch.setFormatter(formatter)
21
+ logger.addHandler(ch)
22
+
23
+ if save_dir:
24
+ os.makedirs(save_dir, exist_ok=True)
25
+ fh = logging.FileHandler(os.path.join(save_dir, out_file))
26
+ fh.setLevel(logging.DEBUG)
27
+ fh.setFormatter(formatter)
28
+ logger.addHandler(fh)
29
+
30
+ return logger
scalelsd/base/utils/metric_logger.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ from collections import defaultdict
3
+ from collections import deque
4
+
5
+ import torch
6
+
7
+
8
+ class SmoothedValue(object):
9
+ """Track a series of values and provide access to smoothed values over a
10
+ window or the global series average.
11
+ """
12
+
13
+ def __init__(self, window_size=20):
14
+ self.deque = deque(maxlen=window_size)
15
+ self.series = []
16
+ self.total = 0.0
17
+ self.count = 0
18
+
19
+ def update(self, value):
20
+ self.deque.append(value)
21
+ self.series.append(value)
22
+ self.count += 1
23
+ self.total += value
24
+
25
+ @property
26
+ def median(self):
27
+ d = torch.tensor(list(self.deque))
28
+ return d.median().item()
29
+
30
+ @property
31
+ def avg(self):
32
+ d = torch.tensor(list(self.deque))
33
+ return d.mean().item()
34
+
35
+ @property
36
+ def global_avg(self):
37
+ return self.total / self.count
38
+
39
+
40
+ class MetricLogger(object):
41
+ def __init__(self, delimiter="\t"):
42
+ self.meters = defaultdict(SmoothedValue)
43
+ self.delimiter = delimiter
44
+
45
+ def update(self, **kwargs):
46
+ for k, v in kwargs.items():
47
+ if isinstance(v, torch.Tensor):
48
+ v = v.item()
49
+ assert isinstance(v, (float, int))
50
+ self.meters[k].update(v)
51
+
52
+ def __getattr__(self, attr):
53
+ if attr in self.meters:
54
+ return self.meters[attr]
55
+ if attr in self.__dict__:
56
+ return self.__dict__[attr]
57
+ raise AttributeError("'{}' object has no attribute '{}'".format(
58
+ type(self).__name__, attr))
59
+
60
+ def __str__(self):
61
+ loss_str = []
62
+ keys = sorted(self.meters)
63
+ # for name, meter in self.meters.items():
64
+ for name in keys:
65
+ meter = self.meters[name]
66
+ loss_str.append(
67
+ "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
68
+ )
69
+ return self.delimiter.join(loss_str)
70
+
71
+ def tensorborad(self, iteration, writter, phase='train'):
72
+ for name, meter in self.meters.items():
73
+ if 'loss' in name:
74
+ # writter.add_scalar('average/{}'.format(name), meter.avg, iteration)
75
+ writter.add_scalar('{}/global/{}'.format(phase,name), meter.global_avg, iteration)
76
+ # writter.add_scalar('median/{}'.format(name), meter.median, iteration)
77
+
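A minimal usage sketch for the two logging utilities above (the directory, step count, and values are illustrative):

```python
from scalelsd.base import setup_logger, MetricLogger

logger = setup_logger('scalelsd', save_dir='exp', out_file='log.txt')
meters = MetricLogger(delimiter='  ')
for step in range(100):
    meters.update(loss=1.0 / (step + 1), lr=1e-4)            # floats only, per the assert above
    if step % 20 == 0:
        logger.info('iter %d | %s', step, str(meters))       # "name: median (global avg)" pairs
```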
scalelsd/base/wireframe.py ADDED
@@ -0,0 +1,110 @@
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import json
6
+
7
+ class WireframeGraph:
8
+ def __init__(self,
9
+ vertices: torch.Tensor,
10
+ v_confidences: torch.Tensor,
11
+ edges: torch.Tensor,
12
+ edge_weights: torch.Tensor,
13
+ frame_width: int,
14
+ frame_height: int):
15
+ self.vertices = vertices
16
+ self.v_confidences = v_confidences
17
+ self.edges = edges
18
+ self.weights = edge_weights
19
+ self.frame_width = frame_width
20
+ self.frame_height = frame_height
21
+
22
+ @classmethod
23
+ def xyxy2indices(cls,junctions, lines):
24
+ # junctions: (N,2)
25
+ # lines: (M,4)
26
+ # return: (M,2)
27
+ dist1 = torch.norm(junctions[None,:,:]-lines[:,None,:2],dim=-1)
28
+ dist2 = torch.norm(junctions[None,:,:]-lines[:,None,2:],dim=-1)
29
+ idx1 = torch.argmin(dist1,dim=-1)
30
+ idx2 = torch.argmin(dist2,dim=-1)
31
+ return torch.stack((idx1,idx2),dim=-1)
32
+ @classmethod
33
+ def load_json(cls, fname):
34
+ with open(fname,'r') as f:
35
+ data = json.load(f)
36
+
37
+
38
+ vertices = torch.tensor(data['vertices'])
39
+ v_confidences = torch.tensor(data['vertices-score'])
40
+ edges = torch.tensor(data['edges'])
41
+ edge_weights = torch.tensor(data['edges-weights'])
42
+ height = data['height']
43
+ width = data['width']
44
+
45
+ return WireframeGraph(vertices,v_confidences,edges,edge_weights,width,height)
46
+
47
+ @property
48
+ def is_empty(self):
49
+ for key, val in self.__dict__.items():
50
+ if val is None:
51
+ return True
52
+ return False
53
+
54
+ @property
55
+ def num_vertices(self):
56
+ if self.is_empty:
57
+ return 0
58
+ return self.vertices.shape[0]
59
+
60
+ @property
61
+ def num_edges(self):
62
+ if self.is_empty:
63
+ return 0
64
+ return self.edges.shape[0]
65
+
66
+
67
+ def line_segments(self, threshold = 0.05, device=None, to_np=False):
68
+ is_valid = self.weights>threshold
69
+ p1 = self.vertices[self.edges[is_valid,0]]
70
+ p2 = self.vertices[self.edges[is_valid,1]]
71
+ ps = self.weights[is_valid]
72
+
73
+ lines = torch.cat((p1,p2,ps[:,None]),dim=-1)
74
+ if device is not None:
75
+ lines = lines.to(device)
76
+ if to_np:
77
+ lines = lines.cpu().numpy()
78
+
79
+ return lines
80
+ # if device != self.device:
81
+
82
+ def rescale(self, image_width, image_height):
83
+ scale_x = float(image_width)/float(self.frame_width)
84
+ scale_y = float(image_height)/float(self.frame_height)
85
+
86
+ self.vertices[:,0] *= scale_x
87
+ self.vertices[:,1] *= scale_y
88
+ self.frame_width = image_width
89
+ self.frame_height = image_height
90
+
91
+ def jsonize(self):
92
+ return {
93
+ 'vertices': self.vertices.cpu().tolist(),
94
+ 'vertices-score': self.v_confidences.cpu().tolist(),
95
+ 'edges': self.edges.cpu().tolist(),
96
+ 'edges-weights': self.weights.cpu().tolist(),
97
+ 'height': self.frame_height,
98
+ 'width': self.frame_width,
99
+ }
100
+ def __repr__(self) -> str:
101
+ return "WireframeGraph\n"+\
102
+ "Vertices: {}\n".format(self.num_vertices)+\
103
+ "Edges: {}\n".format(self.num_edges,) + \
104
+ "Frame size (HxW): {}x{}".format(self.frame_height,self.frame_width)
105
+
106
+ #graph = WireframeGraph()
107
+ if __name__ == "__main__":
108
+ graph = WireframeGraph.load_json('NeuS/public_data/bmvs_clock/hawp/000.json')
109
+ print(graph)
110
+
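A brief round-trip sketch for the graph container above (the JSON path is a placeholder): load, rescale the vertices to a new frame size in place, and serialize back to a plain dict.

```python
from scalelsd.base import WireframeGraph

graph = WireframeGraph.load_json('example.json')    # placeholder path
graph.rescale(image_width=1920, image_height=1080)  # scales vertices in place
print(graph)                                         # vertex / edge counts and frame size
data = graph.jsonize()                               # back to a JSON-serializable dict
```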
scalelsd/encoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .hafm import HAFMencoder
scalelsd/encoder/hafm.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch.utils.data.dataloader import default_collate
4
+
5
+ from ..base import _C  # CUDA line-segment encoder built in scalelsd/base/csrc
6
+
7
+ class HAFMencoder(object):
8
+ def __init__(self, cfg):
9
+ self.dis_th = cfg.ENCODER.DIS_TH
10
+ self.ang_th = cfg.ENCODER.ANG_TH
11
+ self.num_static_pos_lines = cfg.ENCODER.NUM_STATIC_POS_LINES
12
+ self.num_static_neg_lines = cfg.ENCODER.NUM_STATIC_NEG_LINES
13
+ def __call__(self,annotations):
14
+ targets = []
15
+ metas = []
16
+ for ann in annotations:
17
+ t,m = self._process_per_image(ann)
18
+ targets.append(t)
19
+ metas.append(m)
20
+
21
+ return default_collate(targets),metas
22
+
23
+ def adjacent_matrix(self, n, edges, device):
24
+ mat = torch.zeros(n+1,n+1,dtype=torch.bool,device=device)
25
+ if edges.size(0)>0:
26
+ mat[edges[:,0], edges[:,1]] = 1
27
+ mat[edges[:,1], edges[:,0]] = 1
28
+ return mat
29
+
30
+ def _process_per_image(self,ann):
31
+ junctions = ann['junctions']
32
+ device = junctions.device
33
+ height, width = ann['height'], ann['width']
34
+ jmap = torch.zeros((height,width),device=device)
35
+ joff = torch.zeros((2,height,width),device=device,dtype=torch.float32)
36
+ # junctions[:,0] = junctions[:,0].clamp(min=0,max=width-1)
37
+ # junctions[:,1] = junctions[:,1].clamp(min=0,max=height-1)
38
+ xint,yint = junctions[:,0].long(), junctions[:,1].long()
39
+ off_x = junctions[:,0] - xint.float()-0.5
40
+ off_y = junctions[:,1] - yint.float()-0.5
41
+
42
+ jmap[yint,xint] = 1
43
+ joff[0,yint,xint] = off_x
44
+ joff[1,yint,xint] = off_y
45
+
46
+ edges_positive = ann['edges_positive']
47
+ edges_negative = ann['edges_negative']
48
+
49
+ pos_mat = self.adjacent_matrix(junctions.size(0),edges_positive,device)
50
+ neg_mat = self.adjacent_matrix(junctions.size(0),edges_negative,device)
51
+ lines = torch.cat((junctions[edges_positive[:,0]], junctions[edges_positive[:,1]]),dim=-1)
52
+ lines_neg = torch.cat((junctions[edges_negative[:2000,0]],junctions[edges_negative[:2000,1]]),dim=-1)
53
+ lmap, _, _ = _C.encodels(lines,height,width,height,width,lines.size(0))
54
+
55
+ center_points = (lines[:,:2] + lines[:,2:])/2.0
56
+ cmap = torch.zeros((height,width),device=device)
57
+ cxint, cyint = center_points[:,0].long(), center_points[:,1].long()
58
+ cmap[cyint,cxint] = 1
59
+
60
+ # yy,xx = torch.meshgrid(torch.arange(width,device=device),torch.arange(width,device=device))
61
+ # gaussian = torch.exp(-((yy[:,:,None]-center_points[None,None,:,1])**2 + (xx[:,:,None]-center_points[None,None,:,0])**2)/(2*(2*2)))
62
+ # cmap = gaussian.max(dim=-1)[0]
63
+
64
+ lpos = np.random.permutation(lines.cpu().numpy())[:self.num_static_pos_lines]
65
+ lneg = np.random.permutation(lines_neg.cpu().numpy())[:self.num_static_neg_lines]
66
+ # lpos = lines[torch.randperm(lines.size(0),device=device)][:self.num_static_pos_lines]
67
+ # lneg = lines_neg[torch.randperm(lines_neg.size(0),device=device)][:self.num_static_neg_lines]
68
+ lpos = torch.from_numpy(lpos).to(device)
69
+ lneg = torch.from_numpy(lneg).to(device)
70
+
71
+ lpre = torch.cat((lpos,lneg),dim=0)
72
+ _swap = (torch.rand(lpre.size(0))>0.5).to(device)
73
+ lpre[_swap] = lpre[_swap][:,[2,3,0,1]]
74
+ lpre_label = torch.cat(
75
+ [
76
+ torch.ones(lpos.size(0),device=device),
77
+ torch.zeros(lneg.size(0),device=device)
78
+ ])
79
+
80
+ meta = {
81
+ 'junc': junctions,
82
+ 'Lpos': pos_mat,
83
+ 'Lneg': neg_mat,
84
+ 'lpre': lpre,
85
+ 'lpre_label': lpre_label,
86
+ 'lines': lines,
87
+ }
88
+
89
+
90
+ dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None]
91
+ def _normalize(inp):
92
+ mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1])
93
+ return inp/(mag+1e-6)
94
+ md_map = _normalize(lmap[:2])
95
+ st_map = _normalize(lmap[2:4])
96
+ ed_map = _normalize(lmap[4:])
97
+ st_map = lmap[2:4]
98
+ ed_map = lmap[4:]
99
+
100
+ md_ = md_map.reshape(2,-1).t()
101
+ st_ = st_map.reshape(2,-1).t()
102
+ ed_ = ed_map.reshape(2,-1).t()
103
+ Rt = torch.cat(
104
+ (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2),
105
+ torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
106
+ R = torch.cat(
107
+ (torch.cat((md_[:,None,None,0], -md_[:,None,None,1]),dim=2),
108
+ torch.cat((md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
109
+
110
+ Rtst_ = torch.matmul(Rt, st_[:,:,None]).squeeze(-1).t()
111
+ Rted_ = torch.matmul(Rt, ed_[:,:,None]).squeeze(-1).t()
112
+ swap_mask = (Rtst_[1]<0)*(Rted_[1]>0)
113
+ pos_ = Rtst_.clone()
114
+ neg_ = Rted_.clone()
115
+ temp = pos_[:,swap_mask]
116
+ pos_[:,swap_mask] = neg_[:,swap_mask]
117
+ neg_[:,swap_mask] = temp
118
+
119
+ pos_[0] = pos_[0].clamp(min=1e-9)
120
+ pos_[1] = pos_[1].clamp(min=1e-9)
121
+ neg_[0] = neg_[0].clamp(min=1e-9)
122
+ neg_[1] = neg_[1].clamp(max=-1e-9)
123
+
124
+ mask = (dismap.view(-1)<=self.dis_th).float()
125
+
126
+ pos_map = pos_.reshape(-1,height,width)
127
+ neg_map = neg_.reshape(-1,height,width)
128
+
129
+ md_angle = torch.atan2(md_map[1], md_map[0])
130
+ pos_angle = torch.atan2(pos_map[1],pos_map[0])
131
+ neg_angle = torch.atan2(neg_map[1],neg_map[0])
132
+
133
+ mask *= (pos_angle.reshape(-1)>self.ang_th*np.pi/2.0)
134
+ mask *= (neg_angle.reshape(-1)<-self.ang_th*np.pi/2.0)
135
+
136
+ pos_angle_n = pos_angle/(np.pi/2)
137
+ neg_angle_n = -neg_angle/(np.pi/2)
138
+ md_angle_n = md_angle/(np.pi*2) + 0.5
139
+ mask = mask.reshape(height,width)
140
+
141
+
142
+ hafm_ang = torch.cat((md_angle_n[None],pos_angle_n[None],neg_angle_n[None],),dim=0)
143
+ hafm_dis = dismap.clamp(max=self.dis_th)/self.dis_th
144
+ mask = mask[None]
145
+ target = {'jloc':jmap[None],
146
+ 'joff':joff,
147
+ 'cloc': cmap[None],
148
+ 'md': hafm_ang,
149
+ 'dis': hafm_dis,
150
+ 'mask': mask
151
+ }
152
+ return target, meta
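To make the geometry of these targets concrete: 'dis' is the distance to the nearest line clamped at dis_th and divided by it, and 'md' holds the direction towards the closest point plus the two endpoint angles measured in the frame where that direction is the x-axis, each normalized to [0, 1]. The sketch below is only an illustration obtained by inverting those normalization steps for a single pixel; it is not code from this commit.

```python
import numpy as np

def decode_hafm_pixel(px, py, md_n, pos_n, neg_n, dis_n, dis_th=5.0):
    theta_md  = (md_n - 0.5) * 2.0 * np.pi   # inverts md_angle_n = md_angle/(2*pi) + 0.5
    theta_pos = pos_n * np.pi / 2.0          # inverts pos_angle_n = pos_angle/(pi/2)
    theta_neg = -neg_n * np.pi / 2.0         # inverts neg_angle_n = -neg_angle/(pi/2)
    d = dis_n * dis_th                       # inverts dis = clamp(dis, max=dis_th)/dis_th
    md = np.array([np.cos(theta_md), np.sin(theta_md)])

    def endpoint(theta):
        # the endpoint sits at angle theta from the attraction direction,
        # at range d / cos(theta); rotate back into image coordinates
        rot = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta),  np.cos(theta)]])
        return np.array([px, py]) + (d / max(np.cos(theta), 1e-6)) * (rot @ md)

    return endpoint(theta_pos), endpoint(theta_neg)
```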
scalelsd/ssl/backbones/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .build import build_backbone
scalelsd/ssl/backbones/build.py ADDED
@@ -0,0 +1,28 @@
1
+ from .dpt.models import DPTFieldModel
2
+
3
+ def build_dpt(
4
+ basemodel = "vitb_rn50_384",
5
+ features=256,
6
+ readout = "project",
7
+ channels_last = False,
8
+ use_bn = True,
9
+ enable_attention_hooks = False,
10
+ head_size = [[3],[1],[1],[2],[2]],
11
+ use_layer_scale = False,
12
+ **kwargs):
13
+
14
+ model = DPTFieldModel(
15
+ features=features,
16
+ backbone=basemodel,
17
+ readout=readout,
18
+ channels_last=channels_last,
19
+ use_bn=use_bn,
20
+ enable_attention_hooks=enable_attention_hooks,
21
+ head_size=head_size,
22
+ use_layer_scale=use_layer_scale
23
+ )
24
+
25
+ return model
26
+
27
+ def build_backbone(**kwargs):
28
+ return build_dpt(**kwargs)
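For orientation, the default head_size [[3],[1],[1],[2],[2]] sums to 9 output maps and the field model reports stride 2. A quick build-and-probe sketch follows (the printed shape is an expectation based on the code above; creating the ViT-hybrid backbone requires timm from requirements.txt):

```python
import torch
from scalelsd.ssl.backbones import build_backbone

model = build_backbone(basemodel='vitb_rn50_384', features=256,
                       head_size=[[3], [1], [1], [2], [2]])
with torch.no_grad():
    out, _ = model(torch.rand(1, 1, 512, 512))   # grayscale input is repeated to 3 channels
print(out.shape)                                 # expected: (1, 9, 256, 256), i.e. stride-2 maps
```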
scalelsd/ssl/backbones/dpt/__init__.py ADDED
File without changes
scalelsd/ssl/backbones/dpt/base_model.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device("cpu"))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
scalelsd/ssl/backbones/dpt/blocks.py ADDED
@@ -0,0 +1,388 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+
12
+ def _make_encoder(
13
+ backbone,
14
+ features,
15
+ use_pretrained,
16
+ groups=1,
17
+ expand=False,
18
+ exportable=True,
19
+ hooks=None,
20
+ use_vit_only=False,
21
+ use_readout="ignore",
22
+ enable_attention_hooks=False,
23
+ use_layer_scale=False,
24
+ ):
25
+ if backbone == "vitl16_384":
26
+ pretrained = _make_pretrained_vitl16_384(
27
+ use_pretrained,
28
+ hooks=hooks,
29
+ use_readout=use_readout,
30
+ enable_attention_hooks=enable_attention_hooks,
31
+ )
32
+ scratch = _make_scratch(
33
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
34
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
35
+ elif backbone == "vitb_rn50_384":
36
+ pretrained = _make_pretrained_vitb_rn50_384(
37
+ use_pretrained,
38
+ hooks=hooks,
39
+ use_vit_only=use_vit_only,
40
+ use_readout=use_readout,
41
+ enable_attention_hooks=enable_attention_hooks,
42
+ use_layer_scale=use_layer_scale,
43
+ )
44
+ scratch = _make_scratch(
45
+ [256, 512, 768, 768], features, groups=groups, expand=expand
46
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
47
+ elif backbone == "vitb16_384":
48
+ pretrained = _make_pretrained_vitb16_384(
49
+ use_pretrained,
50
+ hooks=hooks,
51
+ use_readout=use_readout,
52
+ enable_attention_hooks=enable_attention_hooks,
53
+ )
54
+ scratch = _make_scratch(
55
+ [96, 192, 384, 768], features, groups=groups, expand=expand
56
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
57
+ elif backbone == "resnext101_wsl":
58
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
59
+ scratch = _make_scratch(
60
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
61
+ ) # efficientnet_lite3
62
+ else:
63
+ raise ValueError(f"Backbone '{backbone}' not implemented")
64
+
65
+
66
+ return pretrained, scratch
67
+
68
+
69
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
70
+ scratch = nn.Module()
71
+
72
+ out_shape1 = out_shape
73
+ out_shape2 = out_shape
74
+ out_shape3 = out_shape
75
+ out_shape4 = out_shape
76
+ if expand == True:
77
+ out_shape1 = out_shape
78
+ out_shape2 = out_shape * 2
79
+ out_shape3 = out_shape * 4
80
+ out_shape4 = out_shape * 8
81
+
82
+ scratch.layer1_rn = nn.Conv2d(
83
+ in_shape[0],
84
+ out_shape1,
85
+ kernel_size=3,
86
+ stride=1,
87
+ padding=1,
88
+ bias=False,
89
+ groups=groups,
90
+ )
91
+ scratch.layer2_rn = nn.Conv2d(
92
+ in_shape[1],
93
+ out_shape2,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=False,
98
+ groups=groups,
99
+ )
100
+ scratch.layer3_rn = nn.Conv2d(
101
+ in_shape[2],
102
+ out_shape3,
103
+ kernel_size=3,
104
+ stride=1,
105
+ padding=1,
106
+ bias=False,
107
+ groups=groups,
108
+ )
109
+ scratch.layer4_rn = nn.Conv2d(
110
+ in_shape[3],
111
+ out_shape4,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1,
115
+ bias=False,
116
+ groups=groups,
117
+ )
118
+
119
+ return scratch
120
+
121
+
122
+ def _make_resnet_backbone(resnet):
123
+ pretrained = nn.Module()
124
+ pretrained.layer1 = nn.Sequential(
125
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
126
+ )
127
+
128
+ pretrained.layer2 = resnet.layer2
129
+ pretrained.layer3 = resnet.layer3
130
+ pretrained.layer4 = resnet.layer4
131
+
132
+ return pretrained
133
+
134
+
135
+ def _make_pretrained_resnext101_wsl(use_pretrained):
136
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
137
+ return _make_resnet_backbone(resnet)
138
+
139
+
140
+ class Interpolate(nn.Module):
141
+ """Interpolation module."""
142
+
143
+ def __init__(self, scale_factor, mode, align_corners=False):
144
+ """Init.
145
+
146
+ Args:
147
+ scale_factor (float): scaling
148
+ mode (str): interpolation mode
149
+ """
150
+ super(Interpolate, self).__init__()
151
+
152
+ self.interp = nn.functional.interpolate
153
+ self.scale_factor = scale_factor
154
+ self.mode = mode
155
+ self.align_corners = align_corners
156
+
157
+ def forward(self, x):
158
+ """Forward pass.
159
+
160
+ Args:
161
+ x (tensor): input
162
+
163
+ Returns:
164
+ tensor: interpolated data
165
+ """
166
+
167
+ x = self.interp(
168
+ x,
169
+ scale_factor=self.scale_factor,
170
+ mode=self.mode,
171
+ align_corners=self.align_corners,
172
+ )
173
+
174
+ # x = self.interp(x, scale_factor=self.scale_factor)
175
+ # x = self.interp(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
176
+
177
+ return x
178
+
179
+
180
+ class ResidualConvUnit(nn.Module):
181
+ """Residual convolution module."""
182
+
183
+ def __init__(self, features):
184
+ """Init.
185
+
186
+ Args:
187
+ features (int): number of features
188
+ """
189
+ super().__init__()
190
+
191
+ self.conv1 = nn.Conv2d(
192
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
193
+ )
194
+
195
+ self.conv2 = nn.Conv2d(
196
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
197
+ )
198
+
199
+ self.relu = nn.ReLU(inplace=True)
200
+
201
+ def forward(self, x):
202
+ """Forward pass.
203
+
204
+ Args:
205
+ x (tensor): input
206
+
207
+ Returns:
208
+ tensor: output
209
+ """
210
+ out = self.relu(x)
211
+ out = self.conv1(out)
212
+ out = self.relu(out)
213
+ out = self.conv2(out)
214
+
215
+ return out + x
216
+
217
+
218
+ class FeatureFusionBlock(nn.Module):
219
+ """Feature fusion block."""
220
+
221
+ def __init__(self, features):
222
+ """Init.
223
+
224
+ Args:
225
+ features (int): number of features
226
+ """
227
+ super(FeatureFusionBlock, self).__init__()
228
+
229
+ self.resConfUnit1 = ResidualConvUnit(features)
230
+ self.resConfUnit2 = ResidualConvUnit(features)
231
+
232
+ def forward(self, *xs):
233
+ """Forward pass.
234
+
235
+ Returns:
236
+ tensor: output
237
+ """
238
+ output = xs[0]
239
+
240
+ if len(xs) == 2:
241
+ output += self.resConfUnit1(xs[1])
242
+
243
+ output = self.resConfUnit2(output)
244
+
245
+ output = nn.functional.interpolate(
246
+ output, scale_factor=2, mode="bilinear", align_corners=True
247
+ )
248
+
249
+ return output
250
+
251
+
252
+ class ResidualConvUnit_custom(nn.Module):
253
+ """Residual convolution module."""
254
+
255
+ def __init__(self, features, activation, bn):
256
+ """Init.
257
+
258
+ Args:
259
+ features (int): number of features
260
+ """
261
+ super().__init__()
262
+
263
+ self.bn = bn
264
+
265
+ self.groups = 1
266
+
267
+ self.conv1 = nn.Conv2d(
268
+ features,
269
+ features,
270
+ kernel_size=3,
271
+ stride=1,
272
+ padding=1,
273
+ bias=not self.bn,
274
+ groups=self.groups,
275
+ )
276
+
277
+ self.conv2 = nn.Conv2d(
278
+ features,
279
+ features,
280
+ kernel_size=3,
281
+ stride=1,
282
+ padding=1,
283
+ bias=not self.bn,
284
+ groups=self.groups,
285
+ )
286
+
287
+ if self.bn == True:
288
+ self.bn1 = nn.BatchNorm2d(features)
289
+ self.bn2 = nn.BatchNorm2d(features)
290
+
291
+ self.activation = activation
292
+
293
+ self.skip_add = nn.quantized.FloatFunctional()
294
+
295
+ def forward(self, x):
296
+ """Forward pass.
297
+
298
+ Args:
299
+ x (tensor): input
300
+
301
+ Returns:
302
+ tensor: output
303
+ """
304
+
305
+ out = self.activation(x)
306
+ out = self.conv1(out)
307
+ if self.bn == True:
308
+ out = self.bn1(out)
309
+
310
+ out = self.activation(out)
311
+ out = self.conv2(out)
312
+ if self.bn == True:
313
+ out = self.bn2(out)
314
+
315
+ if self.groups > 1:
316
+ out = self.conv_merge(out)
317
+
318
+ return self.skip_add.add(out, x)
319
+
320
+ # return out + x
321
+
322
+
323
+ class FeatureFusionBlock_custom(nn.Module):
324
+ """Feature fusion block."""
325
+
326
+ def __init__(
327
+ self,
328
+ features,
329
+ activation,
330
+ deconv=False,
331
+ bn=False,
332
+ expand=False,
333
+ align_corners=True,
334
+ ):
335
+ """Init.
336
+
337
+ Args:
338
+ features (int): number of features
339
+ """
340
+ super(FeatureFusionBlock_custom, self).__init__()
341
+
342
+ self.deconv = deconv
343
+ self.align_corners = align_corners
344
+
345
+ self.groups = 1
346
+
347
+ self.expand = expand
348
+ out_features = features
349
+ if self.expand == True:
350
+ out_features = features // 2
351
+
352
+ self.out_conv = nn.Conv2d(
353
+ features,
354
+ out_features,
355
+ kernel_size=1,
356
+ stride=1,
357
+ padding=0,
358
+ bias=True,
359
+ groups=1,
360
+ )
361
+
362
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
363
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
364
+
365
+ self.skip_add = nn.quantized.FloatFunctional()
366
+
367
+ def forward(self, *xs):
368
+ """Forward pass.
369
+
370
+ Returns:
371
+ tensor: output
372
+ """
373
+ output = xs[0]
374
+
375
+ if len(xs) == 2:
376
+ res = self.resConfUnit1(xs[1])
377
+ output = self.skip_add.add(output, res)
378
+ # output += res
379
+
380
+ output = self.resConfUnit2(output)
381
+
382
+ output = nn.functional.interpolate(
383
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
384
+ )
385
+
386
+ output = self.out_conv(output)
387
+
388
+ return output
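A tiny shape check (a sketch) for the custom fusion block above: it passes the optional skip input through the first residual unit, refines the sum with the second, upsamples by 2x bilinearly, and applies the 1x1 output convolution.

```python
import torch
import torch.nn as nn
from scalelsd.ssl.backbones.dpt.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=True)
x, skip = torch.rand(1, 64, 32, 32), torch.rand(1, 64, 32, 32)
print(block(x, skip).shape)   # torch.Size([1, 64, 64, 64])
```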
scalelsd/ssl/backbones/dpt/midas_net.py ADDED
@@ -0,0 +1,77 @@
1
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_large(BaseModel):
13
+ """Network for monocular depth estimation."""
14
+
15
+ def __init__(self, path=None, features=256, non_negative=True):
16
+ """Init.
17
+
18
+ Args:
19
+ path (str, optional): Path to saved model. Defaults to None.
20
+ features (int, optional): Number of features. Defaults to 256.
21
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
22
+ """
23
+ print("Loading weights: ", path)
24
+
25
+ super(MidasNet_large, self).__init__()
26
+
27
+ use_pretrained = False if path is None else True
28
+
29
+ self.pretrained, self.scratch = _make_encoder(
30
+ backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
31
+ )
32
+
33
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
36
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
37
+
38
+ self.scratch.output_conv = nn.Sequential(
39
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
40
+ Interpolate(scale_factor=2, mode="bilinear"),
41
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
42
+ nn.ReLU(True),
43
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
44
+ nn.ReLU(True) if non_negative else nn.Identity(),
45
+ )
46
+
47
+ if path:
48
+ self.load(path)
49
+
50
+ def forward(self, x):
51
+ """Forward pass.
52
+
53
+ Args:
54
+ x (tensor): input data (image)
55
+
56
+ Returns:
57
+ tensor: depth
58
+ """
59
+
60
+ layer_1 = self.pretrained.layer1(x)
61
+ layer_2 = self.pretrained.layer2(layer_1)
62
+ layer_3 = self.pretrained.layer3(layer_2)
63
+ layer_4 = self.pretrained.layer4(layer_3)
64
+
65
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
66
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
67
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
68
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
69
+
70
+ path_4 = self.scratch.refinenet4(layer_4_rn)
71
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
72
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
73
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
74
+
75
+ out = self.scratch.output_conv(path_1)
76
+
77
+ return torch.squeeze(out, dim=1)
scalelsd/ssl/backbones/dpt/models.py ADDED
@@ -0,0 +1,115 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+ from ..multi_task_head import MultitaskHead
14
+
15
+
16
+ def _make_fusion_block(features, use_bn):
17
+ return FeatureFusionBlock_custom(
18
+ features,
19
+ nn.ReLU(False),
20
+ deconv=False,
21
+ bn=use_bn,
22
+ expand=False,
23
+ align_corners=True,
24
+ )
25
+
26
+
27
+ class DPT(BaseModel):
28
+ def __init__(
29
+ self,
30
+ head,
31
+ features=256,
32
+ backbone="vitb_rn50_384",
33
+ readout="project",
34
+ channels_last=False,
35
+ use_bn=False,
36
+ enable_attention_hooks=False,
37
+ use_layer_scale=False,
38
+ ):
39
+
40
+ super(DPT, self).__init__()
41
+
42
+ self.channels_last = channels_last
43
+
44
+ hooks = {
45
+ "vitb_rn50_384": [0, 1, 8, 11],
46
+ "vitb16_384": [2, 5, 8, 11],
47
+ "vitl16_384": [5, 11, 17, 23],
48
+ }
49
+
50
+ # Instantiate backbone and reassemble blocks
51
+ self.pretrained, self.scratch = _make_encoder(
52
+ backbone,
53
+ features,
54
+ False, # use_pretrained: set to True to initialize the encoder from ImageNet weights instead of training from scratch
55
+ groups=1,
56
+ expand=False,
57
+ exportable=False,
58
+ hooks=hooks[backbone],
59
+ use_readout=readout,
60
+ enable_attention_hooks=enable_attention_hooks,
61
+ use_layer_scale=use_layer_scale,
62
+ )
63
+
64
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
65
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
66
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
67
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
68
+
69
+ self.scratch.output_conv = head
70
+
71
+ def forward(self, x):
72
+ if self.channels_last:
73
+ x = x.contiguous(memory_format=torch.channels_last)
74
+
75
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
76
+
77
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
78
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
79
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
80
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
81
+
82
+ path_4 = self.scratch.refinenet4(layer_4_rn)
83
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
84
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
85
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
86
+
87
+ out = self.scratch.output_conv(path_1)
88
+
89
+ return out
90
+
91
+ class DPTFieldModel(DPT):
92
+ def __init__(self, path=None, non_negative=True, head_size=[[3],[1],[1],[2],[2]], **kwargs):
93
+ features = kwargs["features"] if "features" in kwargs else 256
94
+
95
+ kwargs["use_bn"] = True
96
+
97
+ num_class = sum(sum(head_size,[]))
98
+ head = nn.Sequential(
99
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1),
100
+ # nn.BatchNorm2d(features//2),
101
+ nn.ReLU(True),
102
+ MultitaskHead(features//2, num_class, head_size=head_size),
103
+ )
104
+
105
+ super().__init__(head, **kwargs)
106
+
107
+ self.stride = 2
108
+
109
+ def forward(self, x):
110
+ if x.shape[1] == 1:
111
+ x = torch.cat([x,x,x], dim=1)
112
+
113
+ out = super().forward(x)
114
+ return out, None
115
+
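A hedged usage sketch for DPTFieldModel. It assumes timm still exposes the vit_base_resnet50_384 architecture used by the default vitb_rn50_384 backbone; the input size is arbitrary as long as it is a multiple of the 16-pixel patch size.

    import torch
    from scalelsd.ssl.backbones.dpt.models import DPTFieldModel

    # head_size groups the per-task channels; their sum (3+1+1+2+2 = 9) is the number
    # of output channels produced by MultitaskHead.
    model = DPTFieldModel(head_size=[[3], [1], [1], [2], [2]], features=256)
    model.eval()

    x = torch.randn(1, 1, 512, 512)  # grayscale input is replicated to 3 channels in forward()
    with torch.no_grad():
        out, _ = model(x)
    print(out.shape)  # expected (1, 9, 256, 256), matching the stride-2 output noted in the class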
scalelsd/ssl/backbones/dpt/transforms.py ADDED
@@ -0,0 +1,231 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height)."""
50
+
51
+ def __init__(
52
+ self,
53
+ width,
54
+ height,
55
+ resize_target=True,
56
+ keep_aspect_ratio=False,
57
+ ensure_multiple_of=1,
58
+ resize_method="lower_bound",
59
+ image_interpolation_method=cv2.INTER_AREA,
60
+ ):
61
+ """Init.
62
+
63
+ Args:
64
+ width (int): desired output width
65
+ height (int): desired output height
66
+ resize_target (bool, optional):
67
+ True: Resize the full sample (image, mask, target).
68
+ False: Resize image only.
69
+ Defaults to True.
70
+ keep_aspect_ratio (bool, optional):
71
+ True: Keep the aspect ratio of the input sample.
72
+ Output sample might not have the given width and height, and
73
+ resize behaviour depends on the parameter 'resize_method'.
74
+ Defaults to False.
75
+ ensure_multiple_of (int, optional):
76
+ Output width and height is constrained to be multiple of this parameter.
77
+ Defaults to 1.
78
+ resize_method (str, optional):
79
+ "lower_bound": Output will be at least as large as the given size.
80
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
81
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
82
+ Defaults to "lower_bound".
83
+ """
84
+ self.__width = width
85
+ self.__height = height
86
+
87
+ self.__resize_target = resize_target
88
+ self.__keep_aspect_ratio = keep_aspect_ratio
89
+ self.__multiple_of = ensure_multiple_of
90
+ self.__resize_method = resize_method
91
+ self.__image_interpolation_method = image_interpolation_method
92
+
93
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
94
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
95
+
96
+ if max_val is not None and y > max_val:
97
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
98
+
99
+ if y < min_val:
100
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
101
+
102
+ return y
103
+
104
+ def get_size(self, width, height):
105
+ # determine new height and width
106
+ scale_height = self.__height / height
107
+ scale_width = self.__width / width
108
+
109
+ if self.__keep_aspect_ratio:
110
+ if self.__resize_method == "lower_bound":
111
+ # scale such that output size is lower bound
112
+ if scale_width > scale_height:
113
+ # fit width
114
+ scale_height = scale_width
115
+ else:
116
+ # fit height
117
+ scale_width = scale_height
118
+ elif self.__resize_method == "upper_bound":
119
+ # scale such that output size is upper bound
120
+ if scale_width < scale_height:
121
+ # fit width
122
+ scale_height = scale_width
123
+ else:
124
+ # fit height
125
+ scale_width = scale_height
126
+ elif self.__resize_method == "minimal":
127
+ # scale as little as possible
128
+ if abs(1 - scale_width) < abs(1 - scale_height):
129
+ # fit width
130
+ scale_height = scale_width
131
+ else:
132
+ # fit height
133
+ scale_width = scale_height
134
+ else:
135
+ raise ValueError(
136
+ f"resize_method {self.__resize_method} not implemented"
137
+ )
138
+
139
+ if self.__resize_method == "lower_bound":
140
+ new_height = self.constrain_to_multiple_of(
141
+ scale_height * height, min_val=self.__height
142
+ )
143
+ new_width = self.constrain_to_multiple_of(
144
+ scale_width * width, min_val=self.__width
145
+ )
146
+ elif self.__resize_method == "upper_bound":
147
+ new_height = self.constrain_to_multiple_of(
148
+ scale_height * height, max_val=self.__height
149
+ )
150
+ new_width = self.constrain_to_multiple_of(
151
+ scale_width * width, max_val=self.__width
152
+ )
153
+ elif self.__resize_method == "minimal":
154
+ new_height = self.constrain_to_multiple_of(scale_height * height)
155
+ new_width = self.constrain_to_multiple_of(scale_width * width)
156
+ else:
157
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
158
+
159
+ return (new_width, new_height)
160
+
161
+ def __call__(self, sample):
162
+ width, height = self.get_size(
163
+ sample["image"].shape[1], sample["image"].shape[0]
164
+ )
165
+
166
+ # resize sample
167
+ sample["image"] = cv2.resize(
168
+ sample["image"],
169
+ (width, height),
170
+ interpolation=self.__image_interpolation_method,
171
+ )
172
+
173
+ if self.__resize_target:
174
+ if "disparity" in sample:
175
+ sample["disparity"] = cv2.resize(
176
+ sample["disparity"],
177
+ (width, height),
178
+ interpolation=cv2.INTER_NEAREST,
179
+ )
180
+
181
+ if "depth" in sample:
182
+ sample["depth"] = cv2.resize(
183
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
184
+ )
185
+
186
+ sample["mask"] = cv2.resize(
187
+ sample["mask"].astype(np.float32),
188
+ (width, height),
189
+ interpolation=cv2.INTER_NEAREST,
190
+ )
191
+ sample["mask"] = sample["mask"].astype(bool)
192
+
193
+ return sample
194
+
195
+
196
+ class NormalizeImage(object):
197
+ """Normlize image by given mean and std."""
198
+
199
+ def __init__(self, mean, std):
200
+ self.__mean = mean
201
+ self.__std = std
202
+
203
+ def __call__(self, sample):
204
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
205
+
206
+ return sample
207
+
208
+
209
+ class PrepareForNet(object):
210
+ """Prepare sample for usage as network input."""
211
+
212
+ def __init__(self):
213
+ pass
214
+
215
+ def __call__(self, sample):
216
+ image = np.transpose(sample["image"], (2, 0, 1))
217
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
218
+
219
+ if "mask" in sample:
220
+ sample["mask"] = sample["mask"].astype(np.float32)
221
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
222
+
223
+ if "disparity" in sample:
224
+ disparity = sample["disparity"].astype(np.float32)
225
+ sample["disparity"] = np.ascontiguousarray(disparity)
226
+
227
+ if "depth" in sample:
228
+ depth = sample["depth"].astype(np.float32)
229
+ sample["depth"] = np.ascontiguousarray(depth)
230
+
231
+ return sample
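A small sketch of how these transforms chain together on a sample dict. The target size, mean/std, and the random input are placeholders, not values prescribed by this repository.

    import numpy as np
    from scalelsd.ssl.backbones.dpt.transforms import Resize, NormalizeImage, PrepareForNet

    image = np.random.rand(480, 640, 3).astype(np.float32)  # HWC image in [0, 1]
    sample = {"image": image}

    # Resize the image only (no disparity/depth/mask targets), keeping the aspect ratio
    # and snapping the output size to a multiple of 32.
    sample = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
                    ensure_multiple_of=32, resize_method="minimal")(sample)
    sample = NormalizeImage(mean=0.5, std=0.5)(sample)
    sample = PrepareForNet()(sample)  # HWC float image -> contiguous CHW float32 array

    print(sample["image"].shape)  # channel-first array, ready to wrap in a torch tensor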
scalelsd/ssl/backbones/dpt/vit.py ADDED
@@ -0,0 +1,586 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ activations = {}
10
+
11
+
12
+ def get_activation(name):
13
+ def hook(model, input, output):
14
+ activations[name] = output
15
+
16
+ return hook
17
+
18
+
19
+ attention = {}
20
+
21
+
22
+ def get_attention(name):
23
+ def hook(module, input, output):
24
+ x = input[0]
25
+ B, N, C = x.shape
26
+ qkv = (
27
+ module.qkv(x)
28
+ .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
+ .permute(2, 0, 3, 1, 4).contiguous()
30
+ )
31
+ q, k, v = (
32
+ qkv[0],
33
+ qkv[1],
34
+ qkv[2],
35
+ ) # make torchscript happy (cannot use tensor as tuple)
36
+
37
+ attn = (q @ k.transpose(-2, -1).contiguous()) * module.scale
38
+
39
+ attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
+ attention[name] = attn
41
+
42
+ return hook
43
+
44
+
45
+ def get_mean_attention_map(attn, token, shape):
46
+ attn = attn[:, :, token, 1:]
47
+ attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
+ attn = torch.nn.functional.interpolate(
49
+ attn, size=shape[2:], mode="bicubic", align_corners=False
50
+ ).squeeze(0)
51
+
52
+ all_attn = torch.mean(attn, 0)
53
+
54
+ return all_attn
55
+
56
+
57
+ class Slice(nn.Module):
58
+ def __init__(self, start_index=1):
59
+ super(Slice, self).__init__()
60
+ self.start_index = start_index
61
+
62
+ def forward(self, x):
63
+ return x[:, self.start_index :]
64
+
65
+
66
+ class AddReadout(nn.Module):
67
+ def __init__(self, start_index=1):
68
+ super(AddReadout, self).__init__()
69
+ self.start_index = start_index
70
+
71
+ def forward(self, x):
72
+ if self.start_index == 2:
73
+ readout = (x[:, 0] + x[:, 1]) / 2
74
+ else:
75
+ readout = x[:, 0]
76
+ return x[:, self.start_index :] + readout.unsqueeze(1)
77
+
78
+
79
+ class ProjectReadout(nn.Module):
80
+ def __init__(self, in_features, start_index=1):
81
+ super(ProjectReadout, self).__init__()
82
+ self.start_index = start_index
83
+
84
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
+
86
+ def forward(self, x):
87
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
+ features = torch.cat((x[:, self.start_index :], readout), -1)
89
+
90
+ return self.project(features)
91
+
92
+
93
+ class Transpose(nn.Module):
94
+ def __init__(self, dim0, dim1):
95
+ super(Transpose, self).__init__()
96
+ self.dim0 = dim0
97
+ self.dim1 = dim1
98
+
99
+ def forward(self, x):
100
+ x = x.transpose(self.dim0, self.dim1).contiguous()
101
+ return x
102
+
103
+
104
+ def forward_vit(pretrained, x):
105
+ b, c, h, w = x.shape
106
+
107
+ glob = pretrained.model.forward_flex(x)
108
+
109
+ layer_1 = pretrained.activations["1"]
110
+ layer_2 = pretrained.activations["2"]
111
+ layer_3 = pretrained.activations["3"]
112
+ layer_4 = pretrained.activations["4"]
113
+
114
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118
+
119
+ unflatten = nn.Sequential(
120
+ nn.Unflatten(
121
+ 2,
122
+ torch.Size(
123
+ [
124
+ h // pretrained.model.patch_size[1],
125
+ w // pretrained.model.patch_size[0],
126
+ ]
127
+ ),
128
+ )
129
+ )
130
+
131
+ if layer_1.ndim == 3:
132
+ layer_1 = unflatten(layer_1)
133
+ if layer_2.ndim == 3:
134
+ layer_2 = unflatten(layer_2)
135
+ if layer_3.ndim == 3:
136
+ layer_3 = unflatten(layer_3)
137
+ if layer_4.ndim == 3:
138
+ layer_4 = unflatten(layer_4)
139
+
140
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144
+
145
+ return layer_1, layer_2, layer_3, layer_4
146
+
147
+
148
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
149
+ posemb_tok, posemb_grid = (
150
+ posemb[:, : self.start_index],
151
+ posemb[0, self.start_index :],
152
+ )
153
+
154
+ gs_old = int(math.sqrt(len(posemb_grid)))
155
+
156
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159
+
160
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161
+
162
+ return posemb
163
+
164
+
165
+ def forward_flex(self, x):
166
+ b, c, h, w = x.shape
167
+
168
+ pos_embed = self._resize_pos_embed(
169
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170
+ )
171
+
172
+ B = x.shape[0]
173
+
174
+ if hasattr(self.patch_embed, "backbone"):
175
+ x = self.patch_embed.backbone(x)
176
+ if isinstance(x, (list, tuple)):
177
+ x = x[-1] # last feature if backbone outputs list/tuple of features
178
+
179
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2).contiguous()
180
+
181
+ if getattr(self, "dist_token", None) is not None:
182
+ cls_tokens = self.cls_token.expand(
183
+ B, -1, -1
184
+ ) # stole cls_tokens impl from Phil Wang, thanks
185
+ dist_token = self.dist_token.expand(B, -1, -1)
186
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
187
+ else:
188
+ cls_tokens = self.cls_token.expand(
189
+ B, -1, -1
190
+ ) # stole cls_tokens impl from Phil Wang, thanks
191
+ x = torch.cat((cls_tokens, x), dim=1)
192
+
193
+ x = x + pos_embed
194
+ x = self.pos_drop(x)
195
+
196
+ for blk in self.blocks:
197
+ x = blk(x)
198
+
199
+ x = self.norm(x)
200
+
201
+ return x
202
+
203
+
204
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
205
+ if use_readout == "ignore":
206
+ readout_oper = [Slice(start_index)] * len(features)
207
+ elif use_readout == "add":
208
+ readout_oper = [AddReadout(start_index)] * len(features)
209
+ elif use_readout == "project":
210
+ readout_oper = [
211
+ ProjectReadout(vit_features, start_index) for out_feat in features
212
+ ]
213
+ else:
214
+ assert (
215
+ False
216
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217
+
218
+ return readout_oper
219
+
220
+
221
+ def _make_vit_b16_backbone(
222
+ model,
223
+ features=[96, 192, 384, 768],
224
+ size=[384, 384],
225
+ hooks=[2, 5, 8, 11],
226
+ vit_features=768,
227
+ use_readout="ignore",
228
+ start_index=1,
229
+ enable_attention_hooks=False,
230
+ ):
231
+ pretrained = nn.Module()
232
+
233
+ pretrained.model = model
234
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238
+
239
+ pretrained.activations = activations
240
+
241
+ if enable_attention_hooks:
242
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243
+ get_attention("attn_1")
244
+ )
245
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246
+ get_attention("attn_2")
247
+ )
248
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249
+ get_attention("attn_3")
250
+ )
251
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252
+ get_attention("attn_4")
253
+ )
254
+ pretrained.attention = attention
255
+
256
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257
+
258
+ # 32, 48, 136, 384
259
+ pretrained.act_postprocess1 = nn.Sequential(
260
+ readout_oper[0],
261
+ Transpose(1, 2),
262
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263
+ nn.Conv2d(
264
+ in_channels=vit_features,
265
+ out_channels=features[0],
266
+ kernel_size=1,
267
+ stride=1,
268
+ padding=0,
269
+ ),
270
+ nn.ConvTranspose2d(
271
+ in_channels=features[0],
272
+ out_channels=features[0],
273
+ kernel_size=4,
274
+ stride=4,
275
+ padding=0,
276
+ bias=True,
277
+ dilation=1,
278
+ groups=1,
279
+ ),
280
+ )
281
+
282
+ pretrained.act_postprocess2 = nn.Sequential(
283
+ readout_oper[1],
284
+ Transpose(1, 2),
285
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286
+ nn.Conv2d(
287
+ in_channels=vit_features,
288
+ out_channels=features[1],
289
+ kernel_size=1,
290
+ stride=1,
291
+ padding=0,
292
+ ),
293
+ nn.ConvTranspose2d(
294
+ in_channels=features[1],
295
+ out_channels=features[1],
296
+ kernel_size=2,
297
+ stride=2,
298
+ padding=0,
299
+ bias=True,
300
+ dilation=1,
301
+ groups=1,
302
+ ),
303
+ )
304
+
305
+ pretrained.act_postprocess3 = nn.Sequential(
306
+ readout_oper[2],
307
+ Transpose(1, 2),
308
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309
+ nn.Conv2d(
310
+ in_channels=vit_features,
311
+ out_channels=features[2],
312
+ kernel_size=1,
313
+ stride=1,
314
+ padding=0,
315
+ ),
316
+ )
317
+
318
+ pretrained.act_postprocess4 = nn.Sequential(
319
+ readout_oper[3],
320
+ Transpose(1, 2),
321
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322
+ nn.Conv2d(
323
+ in_channels=vit_features,
324
+ out_channels=features[3],
325
+ kernel_size=1,
326
+ stride=1,
327
+ padding=0,
328
+ ),
329
+ nn.Conv2d(
330
+ in_channels=features[3],
331
+ out_channels=features[3],
332
+ kernel_size=3,
333
+ stride=2,
334
+ padding=1,
335
+ ),
336
+ )
337
+
338
+ pretrained.model.start_index = start_index
339
+ pretrained.model.patch_size = [16, 16]
340
+
341
+ # We inject this function into the VisionTransformer instances so that
342
+ # we can use it with interpolated position embeddings without modifying the library source.
343
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344
+ pretrained.model._resize_pos_embed = types.MethodType(
345
+ _resize_pos_embed, pretrained.model
346
+ )
347
+
348
+ return pretrained
349
+
350
+
351
+ def _make_vit_b_rn50_backbone(
352
+ model,
353
+ features=[256, 512, 768, 768],
354
+ size=[384, 384],
355
+ hooks=[0, 1, 8, 11],
356
+ vit_features=768,
357
+ use_vit_only=False,
358
+ use_readout="ignore",
359
+ start_index=1,
360
+ enable_attention_hooks=False,
361
+ use_layer_scale=False,
362
+ ):
363
+ pretrained = nn.Module()
364
+
365
+ ###
366
+ if use_layer_scale:
367
+ from timm.models.vision_transformer import LayerScale
368
+ for i, block in enumerate(model.blocks):
369
+ block.ls1 = LayerScale(vit_features)
370
+ block.ls2 = LayerScale(vit_features)
371
+
372
+ pretrained.model = model
373
+
374
+ if use_vit_only:
375
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
376
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
377
+ else:
378
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
379
+ get_activation("1")
380
+ )
381
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
382
+ get_activation("2")
383
+ )
384
+
385
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
386
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
387
+
388
+ if enable_attention_hooks:
389
+ pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
390
+ pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
391
+ pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
392
+ pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
393
+ pretrained.attention = attention
394
+
395
+ pretrained.activations = activations
396
+
397
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
398
+
399
+ if use_vit_only:
400
+ pretrained.act_postprocess1 = nn.Sequential(
401
+ readout_oper[0],
402
+ Transpose(1, 2),
403
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
404
+ nn.Conv2d(
405
+ in_channels=vit_features,
406
+ out_channels=features[0],
407
+ kernel_size=1,
408
+ stride=1,
409
+ padding=0,
410
+ ),
411
+ nn.ConvTranspose2d(
412
+ in_channels=features[0],
413
+ out_channels=features[0],
414
+ kernel_size=4,
415
+ stride=4,
416
+ padding=0,
417
+ bias=True,
418
+ dilation=1,
419
+ groups=1,
420
+ ),
421
+ )
422
+
423
+ pretrained.act_postprocess2 = nn.Sequential(
424
+ readout_oper[1],
425
+ Transpose(1, 2),
426
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
427
+ nn.Conv2d(
428
+ in_channels=vit_features,
429
+ out_channels=features[1],
430
+ kernel_size=1,
431
+ stride=1,
432
+ padding=0,
433
+ ),
434
+ nn.ConvTranspose2d(
435
+ in_channels=features[1],
436
+ out_channels=features[1],
437
+ kernel_size=2,
438
+ stride=2,
439
+ padding=0,
440
+ bias=True,
441
+ dilation=1,
442
+ groups=1,
443
+ ),
444
+ )
445
+ else:
446
+ pretrained.act_postprocess1 = nn.Sequential(
447
+ nn.Identity(), nn.Identity(), nn.Identity()
448
+ )
449
+ pretrained.act_postprocess2 = nn.Sequential(
450
+ nn.Identity(), nn.Identity(), nn.Identity()
451
+ )
452
+
453
+ pretrained.act_postprocess3 = nn.Sequential(
454
+ readout_oper[2],
455
+ Transpose(1, 2),
456
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
457
+ nn.Conv2d(
458
+ in_channels=vit_features,
459
+ out_channels=features[2],
460
+ kernel_size=1,
461
+ stride=1,
462
+ padding=0,
463
+ ),
464
+ )
465
+
466
+ pretrained.act_postprocess4 = nn.Sequential(
467
+ readout_oper[3],
468
+ Transpose(1, 2),
469
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
470
+ nn.Conv2d(
471
+ in_channels=vit_features,
472
+ out_channels=features[3],
473
+ kernel_size=1,
474
+ stride=1,
475
+ padding=0,
476
+ ),
477
+ nn.Conv2d(
478
+ in_channels=features[3],
479
+ out_channels=features[3],
480
+ kernel_size=3,
481
+ stride=2,
482
+ padding=1,
483
+ ),
484
+ )
485
+
486
+ pretrained.model.start_index = start_index
487
+ pretrained.model.patch_size = [16, 16]
488
+
489
+ # We inject this function into the VisionTransformer instances so that
490
+ # we can use it with interpolated position embeddings without modifying the library source.
491
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
492
+
493
+ # We inject this function into the VisionTransformer instances so that
494
+ # we can use it with interpolated position embeddings without modifying the library source.
495
+ pretrained.model._resize_pos_embed = types.MethodType(
496
+ _resize_pos_embed, pretrained.model
497
+ )
498
+
499
+ return pretrained
500
+
501
+
502
+ def _make_pretrained_vitb_rn50_384(
503
+ pretrained,
504
+ use_readout="ignore",
505
+ hooks=None,
506
+ use_vit_only=False,
507
+ enable_attention_hooks=False,
508
+ use_layer_scale=False,
509
+ ):
510
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
511
+
512
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
513
+ return _make_vit_b_rn50_backbone(
514
+ model,
515
+ features=[256, 512, 768, 768],
516
+ size=[384, 384],
517
+ hooks=hooks,
518
+ use_vit_only=use_vit_only,
519
+ use_readout=use_readout,
520
+ enable_attention_hooks=enable_attention_hooks,
521
+ use_layer_scale=use_layer_scale,
522
+ )
523
+
524
+
525
+ def _make_pretrained_vitl16_384(
526
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
527
+ ):
528
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
529
+
530
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
531
+ return _make_vit_b16_backbone(
532
+ model,
533
+ features=[256, 512, 1024, 1024],
534
+ hooks=hooks,
535
+ vit_features=1024,
536
+ use_readout=use_readout,
537
+ enable_attention_hooks=enable_attention_hooks,
538
+ )
539
+
540
+
541
+ def _make_pretrained_vitb16_384(
542
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
543
+ ):
544
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
545
+
546
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
547
+ return _make_vit_b16_backbone(
548
+ model,
549
+ features=[96, 192, 384, 768],
550
+ hooks=hooks,
551
+ use_readout=use_readout,
552
+ enable_attention_hooks=enable_attention_hooks,
553
+ )
554
+
555
+
556
+ def _make_pretrained_deitb16_384(
557
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
558
+ ):
559
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
560
+
561
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
562
+ return _make_vit_b16_backbone(
563
+ model,
564
+ features=[96, 192, 384, 768],
565
+ hooks=hooks,
566
+ use_readout=use_readout,
567
+ enable_attention_hooks=enable_attention_hooks,
568
+ )
569
+
570
+
571
+ def _make_pretrained_deitb16_distil_384(
572
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
573
+ ):
574
+ model = timm.create_model(
575
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
576
+ )
577
+
578
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
579
+ return _make_vit_b16_backbone(
580
+ model,
581
+ features=[96, 192, 384, 768],
582
+ hooks=hooks,
583
+ use_readout=use_readout,
584
+ start_index=2,
585
+ enable_attention_hooks=enable_attention_hooks,
586
+ )
scalelsd/ssl/backbones/multi_task_head.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ class MultitaskHead(nn.Module):
4
+ def __init__(self, input_channels, num_class, head_size):
5
+ super(MultitaskHead, self).__init__()
6
+
7
+ m = int(input_channels / 4)
8
+ heads = []
9
+ for output_channels in sum(head_size, []):
10
+ heads.append(
11
+ nn.Sequential(
12
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
13
+ nn.ReLU(inplace=True),
14
+ nn.Conv2d(m, output_channels, kernel_size=1),
15
+ )
16
+ )
17
+ self.heads = nn.ModuleList(heads)
18
+ assert num_class == sum(sum(head_size, []))
19
+
20
+ def forward(self, x):
21
+ # import pdb;pdb.set_trace()
22
+ return torch.cat([head(x) for head in self.heads], dim=1)
23
+
24
+
25
+ class AngleDistanceHead(nn.Module):
26
+ def __init__(self, input_channels, num_class, head_size):
27
+ super(AngleDistanceHead, self).__init__()
28
+
29
+ m = int(input_channels/4)
30
+
31
+ heads = []
32
+ for output_channels in sum(head_size, []):
33
+ if output_channels != 2:
34
+ heads.append(
35
+ nn.Sequential(
36
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(m, output_channels, kernel_size=1),
39
+ )
40
+ )
41
+ else:
42
+ heads.append(
43
+ nn.Sequential(
44
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
45
+ nn.ReLU(inplace=True),
46
+ CosineSineLayer(m)  # NOTE: CosineSineLayer is not defined or imported in this file; it must be provided elsewhere before AngleDistanceHead is used
47
+ )
48
+ )
49
+ self.heads = nn.ModuleList(heads)
50
+ assert num_class == sum(sum(head_size, []))
51
+ def forward(self, x):
52
+ return torch.cat([head(x) for head in self.heads], dim=1)
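A short, self-contained check of MultitaskHead; the channel counts below are illustrative, and the repository itself wires this head into DPTFieldModel above.

    import torch
    from scalelsd.ssl.backbones.multi_task_head import MultitaskHead

    # Five small conv heads whose outputs are concatenated: 3 + 1 + 1 + 2 + 2 = 9 channels.
    head = MultitaskHead(input_channels=128, num_class=9, head_size=[[3], [1], [1], [2], [2]])

    x = torch.randn(2, 128, 64, 64)
    out = head(x)
    print(out.shape)  # torch.Size([2, 9, 64, 64])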
scalelsd/ssl/config/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .project_config import Config
2
+ from .utils import *
scalelsd/ssl/config/dataset/hpatches_dataset.yaml ADDED
@@ -0,0 +1,105 @@
1
+ ### General dataset parameters
2
+ dataset_name: "hpatches"
3
+ add_augmentation_to_all_splits: False
4
+ gray_scale: True
5
+ # Ground truth source ('official' or path to the exported h5 dataset.)
6
+ # gt_source_train: "" # Fill with your own export file
7
+ # gt_source_test: "" # Fill with your own export file
8
+ # Return type: (1) single (to train the detector only)
9
+ # or (2) paired_desc (to train the detector + descriptor)
10
+ return_type: "single"
11
+ random_seed: 0
12
+
13
+ ### Descriptor training parameters
14
+ # Number of points extracted per line
15
+ max_num_samples: 10
16
+ # Max number of training line points extracted in the whole image
17
+ max_pts: 1000
18
+ # Min distance between two points on a line (in pixels)
19
+ min_dist_pts: 10
20
+ # Small jittering of the sampled points during training
21
+ jittering: 0
22
+
23
+ alteration: "all"
24
+ max_side: 1200
25
+
26
+ ### Data preprocessing configuration
27
+ preprocessing:
28
+ resize: [512, 512]
29
+ blur_size: 11
30
+ augmentation:
31
+ random_scaling:
32
+ enable: True
33
+ range: [0.7, 1.5]
34
+ photometric:
35
+ enable: true
36
+ primitives: ['random_brightness', 'random_contrast',
37
+ 'additive_speckle_noise', 'additive_gaussian_noise',
38
+ 'additive_shade', 'motion_blur' ]
39
+ params:
40
+ random_brightness: {brightness: 0.2}
41
+ random_contrast: {contrast: [0.3, 1.5]}
42
+ additive_gaussian_noise: {stddev_range: [0, 10]}
43
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
44
+ additive_shade:
45
+ transparency_range: [-0.5, 0.5]
46
+ kernel_size_range: [100, 150]
47
+ motion_blur: {max_kernel_size: 3}
48
+ random_order: True
49
+ homographic:
50
+ enable: true
51
+ params:
52
+ translation: true
53
+ rotation: true
54
+ scaling: true
55
+ perspective: true
56
+ scaling_amplitude: 0.2
57
+ perspective_amplitude_x: 0.2
58
+ perspective_amplitude_y: 0.2
59
+ patch_ratio: 0.85
60
+ max_angle: 1.57
61
+ allow_artifacts: true
62
+ valid_border_margin: 3
63
+
64
+ ## Homography adaptation configuration
65
+ homography_adaptation:
66
+ num_iter: 10
67
+ valid_border_margin: 3
68
+ min_counts: 3
69
+ homographies:
70
+ translation: true
71
+ rotation: true
72
+ scaling: true
73
+ perspective: true
74
+ scaling_amplitude: 0.2
75
+ perspective_amplitude_x: 0.2
76
+ perspective_amplitude_y: 0.2
77
+ allow_artifacts: true
78
+ patch_ratio: 0.85
79
+
80
+ data:
81
+ name: hpatches
82
+ dataset_dir: HPatches_sequences
83
+ alteration: all
84
+ max_side: 1200
85
+ batch_size: 1
86
+ num_workers: 4
87
+ model:
88
+ name: deeplsd
89
+ tiny: False
90
+ sharpen: True
91
+ line_neighborhood: 5
92
+ loss_weights:
93
+ df: 1.
94
+ angle: 1.
95
+ detect_lines: True
96
+ multiscale: False
97
+ scale_factors: [1., 1.5]
98
+ line_detection_params:
99
+ grad_nfa: True
100
+ merge: False
101
+ optimize: False
102
+ use_vps: False
103
+ optimize_vps: False
104
+ filtering: True
105
+ grad_thresh: 3
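These dataset YAMLs are plain configuration files; below is a minimal sketch of reading one. PyYAML and the relative path are the only assumptions here, since the project's own loader lives in scalelsd/ssl/config.

    import yaml

    # Load the HPatches dataset configuration into a plain dict and inspect a few fields.
    with open("scalelsd/ssl/config/dataset/hpatches_dataset.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    print(cfg["dataset_name"])                       # "hpatches"
    print(cfg["preprocessing"]["resize"])            # [512, 512]
    print(cfg["homography_adaptation"]["num_iter"])  # 10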
scalelsd/ssl/config/dataset/nyu_dataset.yaml ADDED
@@ -0,0 +1,77 @@
1
+ ### General dataset parameters
2
+ dataset_name: "nyu"
3
+ add_augmentation_to_all_splits: False
4
+ gray_scale: True
5
+ # Ground truth source ('official' or path to the exported h5 dataset.)
6
+ # gt_source_train: "" # Fill with your own export file
7
+ # gt_source_test: "" # Fill with your own export file
8
+ # Return type: (1) single (to train the detector only)
9
+ # or (2) paired_desc (to train the detector + descriptor)
10
+ return_type: "single"
11
+ random_seed: 0
12
+
13
+ val_size: 49
14
+
15
+ ### Descriptor training parameters
16
+ # Number of points extracted per line
17
+ max_num_samples: 10
18
+ # Max number of training line points extracted in the whole image
19
+ max_pts: 1000
20
+ # Min distance between two points on a line (in pixels)
21
+ min_dist_pts: 10
22
+ # Small jittering of the sampled points during training
23
+ jittering: 0
24
+
25
+ ### Data preprocessing configuration
26
+ preprocessing:
27
+ resize: [512, 512]
28
+ blur_size: 11
29
+ augmentation:
30
+ random_scaling:
31
+ enable: True
32
+ range: [0.7, 1.5]
33
+ photometric:
34
+ enable: true
35
+ primitives: ['random_brightness', 'random_contrast',
36
+ 'additive_speckle_noise', 'additive_gaussian_noise',
37
+ 'additive_shade', 'motion_blur' ]
38
+ params:
39
+ random_brightness: {brightness: 0.2}
40
+ random_contrast: {contrast: [0.3, 1.5]}
41
+ additive_gaussian_noise: {stddev_range: [0, 10]}
42
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
43
+ additive_shade:
44
+ transparency_range: [-0.5, 0.5]
45
+ kernel_size_range: [100, 150]
46
+ motion_blur: {max_kernel_size: 3}
47
+ random_order: True
48
+ homographic:
49
+ enable: true
50
+ params:
51
+ translation: true
52
+ rotation: true
53
+ scaling: true
54
+ perspective: true
55
+ scaling_amplitude: 0.2
56
+ perspective_amplitude_x: 0.2
57
+ perspective_amplitude_y: 0.2
58
+ patch_ratio: 0.85
59
+ max_angle: 1.57
60
+ allow_artifacts: true
61
+ valid_border_margin: 3
62
+
63
+ ## Homography adaptation configuration
64
+ homography_adaptation:
65
+ num_iter: 10
66
+ valid_border_margin: 3
67
+ min_counts: 3
68
+ homographies:
69
+ translation: true
70
+ rotation: true
71
+ scaling: true
72
+ perspective: true
73
+ scaling_amplitude: 0.2
74
+ perspective_amplitude_x: 0.2
75
+ perspective_amplitude_y: 0.2
76
+ allow_artifacts: true
77
+ patch_ratio: 0.85
scalelsd/ssl/config/dataset/official_yorkurban_dataset.yaml ADDED
@@ -0,0 +1,75 @@
1
+ ### General dataset parameters
2
+ dataset_name: "official_yorkurban"
3
+ add_augmentation_to_all_splits: False
4
+ gray_scale: True
5
+ # Ground truth source ('official' or path to the exported h5 dataset.)
6
+ # gt_source_train: "" # Fill with your own export file
7
+ # gt_source_test: "" # Fill with your own export file
8
+ # Return type: (1) single (to train the detector only)
9
+ # or (2) paired_desc (to train the detector + descriptor)
10
+ return_type: "single"
11
+ random_seed: 0
12
+
13
+ ### Descriptor training parameters
14
+ # Number of points extracted per line
15
+ max_num_samples: 10
16
+ # Max number of training line points extracted in the whole image
17
+ max_pts: 1000
18
+ # Min distance between two points on a line (in pixels)
19
+ min_dist_pts: 10
20
+ # Small jittering of the sampled points during training
21
+ jittering: 0
22
+
23
+ ### Data preprocessing configuration
24
+ preprocessing:
25
+ resize: [512, 512]
26
+ blur_size: 11
27
+ augmentation:
28
+ random_scaling:
29
+ enable: True
30
+ range: [0.7, 1.5]
31
+ photometric:
32
+ enable: true
33
+ primitives: ['random_brightness', 'random_contrast',
34
+ 'additive_speckle_noise', 'additive_gaussian_noise',
35
+ 'additive_shade', 'motion_blur' ]
36
+ params:
37
+ random_brightness: {brightness: 0.2}
38
+ random_contrast: {contrast: [0.3, 1.5]}
39
+ additive_gaussian_noise: {stddev_range: [0, 10]}
40
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
41
+ additive_shade:
42
+ transparency_range: [-0.5, 0.5]
43
+ kernel_size_range: [100, 150]
44
+ motion_blur: {max_kernel_size: 3}
45
+ random_order: True
46
+ homographic:
47
+ enable: true
48
+ params:
49
+ translation: true
50
+ rotation: true
51
+ scaling: true
52
+ perspective: true
53
+ scaling_amplitude: 0.2
54
+ perspective_amplitude_x: 0.2
55
+ perspective_amplitude_y: 0.2
56
+ patch_ratio: 0.85
57
+ max_angle: 1.57
58
+ allow_artifacts: true
59
+ valid_border_margin: 3
60
+
61
+ ## Homography adaptation configuration
62
+ homography_adaptation:
63
+ num_iter: 10
64
+ valid_border_margin: 3
65
+ min_counts: 3
66
+ homographies:
67
+ translation: true
68
+ rotation: true
69
+ scaling: true
70
+ perspective: true
71
+ scaling_amplitude: 0.2
72
+ perspective_amplitude_x: 0.2
73
+ perspective_amplitude_y: 0.2
74
+ allow_artifacts: true
75
+ patch_ratio: 0.85
scalelsd/ssl/config/dataset/rdnim_dataset.yaml ADDED
@@ -0,0 +1,77 @@
1
+ ### General dataset parameters
2
+ dataset_name: "rdnim"
3
+ add_augmentation_to_all_splits: False
4
+ gray_scale: True
5
+ # Ground truth source ('official' or path to the exported h5 dataset.)
6
+ # gt_source_train: "" # Fill with your own export file
7
+ # gt_source_test: "" # Fill with your own export file
8
+ # Return type: (1) single (to train the detector only)
9
+ # or (2) paired_desc (to train the detector + descriptor)
10
+ return_type: "single"
11
+ random_seed: 0
12
+
13
+ ### Descriptor training parameters
14
+ # Number of points extracted per line
15
+ max_num_samples: 10
16
+ # Max number of training line points extracted in the whole image
17
+ max_pts: 1000
18
+ # Min distance between two points on a line (in pixels)
19
+ min_dist_pts: 10
20
+ # Small jittering of the sampled points during training
21
+ jittering: 0
22
+
23
+ reference: "night"
24
+
25
+ ### Data preprocessing configuration
26
+ preprocessing:
27
+ resize: [512, 512]
28
+ blur_size: 11
29
+ augmentation:
30
+ random_scaling:
31
+ enable: True
32
+ range: [0.7, 1.5]
33
+ photometric:
34
+ enable: true
35
+ primitives: ['random_brightness', 'random_contrast',
36
+ 'additive_speckle_noise', 'additive_gaussian_noise',
37
+ 'additive_shade', 'motion_blur' ]
38
+ params:
39
+ random_brightness: {brightness: 0.2}
40
+ random_contrast: {contrast: [0.3, 1.5]}
41
+ additive_gaussian_noise: {stddev_range: [0, 10]}
42
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
43
+ additive_shade:
44
+ transparency_range: [-0.5, 0.5]
45
+ kernel_size_range: [100, 150]
46
+ motion_blur: {max_kernel_size: 3}
47
+ random_order: True
48
+ homographic:
49
+ enable: true
50
+ params:
51
+ translation: true
52
+ rotation: true
53
+ scaling: true
54
+ perspective: true
55
+ scaling_amplitude: 0.2
56
+ perspective_amplitude_x: 0.2
57
+ perspective_amplitude_y: 0.2
58
+ patch_ratio: 0.85
59
+ max_angle: 1.57
60
+ allow_artifacts: true
61
+ valid_border_margin: 3
62
+
63
+ ## Homography adaptation configuration
64
+ homography_adaptation:
65
+ num_iter: 10
66
+ valid_border_margin: 3
67
+ min_counts: 3
68
+ homographies:
69
+ translation: true
70
+ rotation: true
71
+ scaling: true
72
+ perspective: true
73
+ scaling_amplitude: 0.2
74
+ perspective_amplitude_x: 0.2
75
+ perspective_amplitude_y: 0.2
76
+ allow_artifacts: true
77
+ patch_ratio: 0.85
scalelsd/ssl/config/dataset/synthetic_dataset-1024.yaml ADDED
@@ -0,0 +1,49 @@
1
+ ### General dataset parameters
2
+ dataset_name: "synthetic_shape"
3
+ primitives: "all"
4
+ add_augmentation_to_all_splits: True
5
+ test_augmentation_seed: 200
6
+ # Shape generation configuration
7
+ generation:
8
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
9
+ split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
10
+ random_seed: 10
11
+ image_size: [960, 1280]
12
+ min_len: 0.0985
13
+ min_label_len: 0.099
14
+ params:
15
+ generate_background:
16
+ min_kernel_size: 150
17
+ max_kernel_size: 500
18
+ min_rad_ratio: 0.02
19
+ max_rad_ratio: 0.031
20
+ draw_stripes:
21
+ transform_params: [0.1, 0.1]
22
+ draw_multiple_polygons:
23
+ kernel_boundaries: [50, 100]
24
+
25
+ ### Data preprocessing configuration.
26
+ preprocessing:
27
+ resize: [1024, 1024]
28
+ blur_size: 11
29
+ augmentation:
30
+ photometric:
31
+ enable: True
32
+ primitives: 'all'
33
+ params: {}
34
+ random_order: True
35
+ homographic:
36
+ enable: True
37
+ params:
38
+ translation: true
39
+ rotation: true
40
+ scaling: true
41
+ perspective: true
42
+ scaling_amplitude: 0.2
43
+ perspective_amplitude_x: 0.2
44
+ perspective_amplitude_y: 0.2
45
+ patch_ratio: 0.8
46
+ max_angle: 1.57
47
+ allow_artifacts: true
48
+ translation_overflow: 0.05
49
+ valid_border_margin: 0
scalelsd/ssl/config/dataset/synthetic_dataset-2k.yaml ADDED
@@ -0,0 +1,50 @@
1
+ ### General dataset parameters
2
+ dataset_name: "synthetic_shape"
3
+ primitives: "all"
4
+ add_augmentation_to_all_splits: True
5
+ test_augmentation_seed: 200
6
+ alias: 2k
7
+ # Shape generation configuration
8
+ generation:
9
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
10
+ split_sizes: {'train': 2000, 'val': 200, 'test': 400}
11
+ random_seed: 10
12
+ image_size: [960, 1280]
13
+ min_len: 0.0985
14
+ min_label_len: 0.099
15
+ params:
16
+ generate_background:
17
+ min_kernel_size: 150
18
+ max_kernel_size: 500
19
+ min_rad_ratio: 0.02
20
+ max_rad_ratio: 0.031
21
+ draw_stripes:
22
+ transform_params: [0.1, 0.1]
23
+ draw_multiple_polygons:
24
+ kernel_boundaries: [50, 100]
25
+
26
+ ### Data preprocessing configuration.
27
+ preprocessing:
28
+ resize: [512, 512]
29
+ blur_size: 11
30
+ augmentation:
31
+ photometric:
32
+ enable: True
33
+ primitives: 'all'
34
+ params: {}
35
+ random_order: True
36
+ homographic:
37
+ enable: True
38
+ params:
39
+ translation: true
40
+ rotation: true
41
+ scaling: true
42
+ perspective: true
43
+ scaling_amplitude: 0.2
44
+ perspective_amplitude_x: 0.2
45
+ perspective_amplitude_y: 0.2
46
+ patch_ratio: 0.8
47
+ max_angle: 1.57
48
+ allow_artifacts: true
49
+ translation_overflow: 0.05
50
+ valid_border_margin: 0
scalelsd/ssl/config/dataset/synthetic_dataset-4k.yaml ADDED
@@ -0,0 +1,50 @@
1
+ ### General dataset parameters
2
+ dataset_name: "synthetic_shape"
3
+ primitives: "all"
4
+ add_augmentation_to_all_splits: True
5
+ test_augmentation_seed: 200
6
+ alias: 4k
7
+ # Shape generation configuration
8
+ generation:
9
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
10
+ split_sizes: {'train': 4000, 'val': 2000, 'test': 400}
11
+ random_seed: 10
12
+ image_size: [960, 1280]
13
+ min_len: 0.0985
14
+ min_label_len: 0.099
15
+ params:
16
+ generate_background:
17
+ min_kernel_size: 150
18
+ max_kernel_size: 500
19
+ min_rad_ratio: 0.02
20
+ max_rad_ratio: 0.031
21
+ draw_stripes:
22
+ transform_params: [0.1, 0.1]
23
+ draw_multiple_polygons:
24
+ kernel_boundaries: [50, 100]
25
+
26
+ ### Data preprocessing configuration.
27
+ preprocessing:
28
+ resize: [512, 512]
29
+ blur_size: 11
30
+ augmentation:
31
+ photometric:
32
+ enable: True
33
+ primitives: 'all'
34
+ params: {}
35
+ random_order: True
36
+ homographic:
37
+ enable: True
38
+ params:
39
+ translation: true
40
+ rotation: true
41
+ scaling: true
42
+ perspective: true
43
+ scaling_amplitude: 0.2
44
+ perspective_amplitude_x: 0.2
45
+ perspective_amplitude_y: 0.2
46
+ patch_ratio: 0.8
47
+ max_angle: 1.57
48
+ allow_artifacts: true
49
+ translation_overflow: 0.05
50
+ valid_border_margin: 0
scalelsd/ssl/config/dataset/synthetic_dataset-large.yaml ADDED
@@ -0,0 +1,50 @@
1
+ ### General dataset parameters
2
+ dataset_name: "synthetic_shape"
3
+ primitives: "all"
4
+ add_augmentation_to_all_splits: True
5
+ test_augmentation_seed: 200
6
+ alias: "synthetic_shape_large"
7
+ # Shape generation configuration
8
+ generation:
9
+ split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
10
+ # split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
11
+ random_seed: 10
12
+ image_size: [960, 1280]
13
+ min_len: 0.0985
14
+ min_label_len: 0.099
15
+ params:
16
+ generate_background:
17
+ min_kernel_size: 150
18
+ max_kernel_size: 500
19
+ min_rad_ratio: 0.02
20
+ max_rad_ratio: 0.031
21
+ draw_stripes:
22
+ transform_params: [0.1, 0.1]
23
+ draw_multiple_polygons:
24
+ kernel_boundaries: [50, 100]
25
+
26
+ ### Data preprocessing configuration.
27
+ preprocessing:
28
+ resize: [512, 512]
29
+ blur_size: 11
30
+ augmentation:
31
+ photometric:
32
+ enable: True
33
+ primitives: 'all'
34
+ params: {}
35
+ random_order: True
36
+ homographic:
37
+ enable: True
38
+ params:
39
+ translation: true
40
+ rotation: true
41
+ scaling: true
42
+ perspective: true
43
+ scaling_amplitude: 0.2
44
+ perspective_amplitude_x: 0.2
45
+ perspective_amplitude_y: 0.2
46
+ patch_ratio: 0.8
47
+ max_angle: 1.57
48
+ allow_artifacts: true
49
+ translation_overflow: 0.05
50
+ valid_border_margin: 0
scalelsd/ssl/config/dataset/synthetic_dataset.yaml ADDED
@@ -0,0 +1,51 @@
1
+ ### General dataset parameters
2
+ dataset_name: "synthetic_shape"
3
+ primitives: "all"
4
+ add_augmentation_to_all_splits: True
5
+ test_augmentation_seed: 200
6
+ # Shape generation configuration
7
+ generation:
8
+ # split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
9
+ # split_sizes: {'train': 2000, 'val': 2000, 'test': 400}
10
+ split_sizes: {'train': 100, 'val': 100, 'test': 100}
11
+ random_seed: 10
12
+ # image_size: [960, 1280]
13
+ image_size: [1024, 1024]
14
+ min_len: 0.0985
15
+ min_label_len: 0.099
16
+ params:
17
+ generate_background:
18
+ min_kernel_size: 150
19
+ max_kernel_size: 500
20
+ min_rad_ratio: 0.02
21
+ max_rad_ratio: 0.031
22
+ draw_stripes:
23
+ transform_params: [0.1, 0.1]
24
+ draw_multiple_polygons:
25
+ kernel_boundaries: [50, 100]
26
+
27
+ ### Data preprocessing configuration.
28
+ preprocessing:
29
+ resize: [512, 512]
30
+ blur_size: 11
31
+ augmentation:
32
+ photometric:
33
+ enable: True
34
+ primitives: 'all'
35
+ params: {}
36
+ random_order: True
37
+ homographic:
38
+ enable: True
39
+ params:
40
+ translation: true
41
+ rotation: true
42
+ scaling: true
43
+ perspective: true
44
+ scaling_amplitude: 0.2
45
+ perspective_amplitude_x: 0.2
46
+ perspective_amplitude_y: 0.2
47
+ patch_ratio: 0.8
48
+ max_angle: 1.57
49
+ allow_artifacts: true
50
+ translation_overflow: 0.05
51
+ valid_border_margin: 0
scalelsd/ssl/config/dataset/wireframe_official_gt copy.yaml ADDED
@@ -0,0 +1,86 @@
1
+ dataset_name: "wireframe"
2
+ add_augmentation_to_all_splits: False
3
+ gray_scale: True
4
+ # return_type: "paired_desc"
5
+ random_seed: 0
6
+ # Ground truth source (official or path to the exported h5 dataset.)
7
+ gt_source_train: "official"
8
+ gt_source_test: "official"
9
+ # Data preprocessing configuration.
10
+ preprocessing:
11
+ resize: [512, 512]
12
+ blur_size: 11
13
+ augmentation:
14
+ random_scaling:
15
+ enable: True
16
+ range: [0.7, 1.5]
17
+ photometric:
18
+ enable: true
19
+ primitives: ['random_brightness', 'random_contrast',
20
+ 'additive_speckle_noise', 'additive_gaussian_noise',
21
+ 'additive_shade', 'motion_blur' ]
22
+ params:
23
+ random_brightness: {brightness: 0.2}
24
+ random_contrast: {contrast: [0.3, 1.5]}
25
+ additive_gaussian_noise: {stddev_range: [0, 10]}
26
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
27
+ additive_shade:
28
+ transparency_range: [-0.5, 0.5]
29
+ kernel_size_range: [100, 150]
30
+ motion_blur: {max_kernel_size: 3}
31
+ random_order: True
32
+ homographic:
33
+ enable: true
34
+ params:
35
+ translation: true
36
+ rotation: true
37
+ scaling: true
38
+ perspective: true
39
+ scaling_amplitude: 0.2
40
+ perspective_amplitude_x: 0.2
41
+ perspective_amplitude_y: 0.2
42
+ patch_ratio: 0.85
43
+ max_angle: 1.57
44
+ allow_artifacts: true
45
+ valid_border_margin: 3
46
+ # The homography adaptation configuration
47
+ homography_adaptation:
48
+ num_iter: 100
49
+ aggregation: 'sum'
50
+ mode: 'ver1'
51
+ valid_border_margin: 3
52
+ min_counts: 30
53
+ homographies:
54
+ translation: true
55
+ rotation: true
56
+ scaling: true
57
+ perspective: true
58
+ scaling_amplitude: 0.2
59
+ perspective_amplitude_x: 0.2
60
+ perspective_amplitude_y: 0.2
61
+ allow_artifacts: true
62
+ patch_ratio: 0.85
63
+ # Evaluation related config
64
+ evaluation:
65
+ repeatability:
66
+ # Initial random seed used to sample homographic augmentation
67
+ seed: 200
68
+ # Parameter used to sample illumination change evaluation set.
69
+ photometric:
70
+ enable: False
71
+ # Parameter used to sample viewpoint change evaluation set.
72
+ homographic:
73
+ enable: True
74
+ num_samples: 2
75
+ params:
76
+ translation: true
77
+ rotation: true
78
+ scaling: true
79
+ perspective: true
80
+ scaling_amplitude: 0.2
81
+ perspective_amplitude_x: 0.2
82
+ perspective_amplitude_y: 0.2
83
+ patch_ratio: 0.85
84
+ max_angle: 1.57
85
+ allow_artifacts: true
86
+ valid_border_margin: 3
scalelsd/ssl/config/dataset/wireframe_official_gt.yaml ADDED
@@ -0,0 +1,86 @@
1
+ dataset_name: "wireframe"
2
+ add_augmentation_to_all_splits: False
3
+ gray_scale: True
4
+ # return_type: "paired_desc"
5
+ random_seed: 0
6
+ # Ground truth source (official or path to the exported h5 dataset.)
7
+ gt_source_train: "official"
8
+ gt_source_test: "official"
9
+ # Data preprocessing configuration.
10
+ preprocessing:
11
+ resize: [512, 512]
12
+ blur_size: 11
13
+ augmentation:
14
+ random_scaling:
15
+ enable: True
16
+ range: [0.7, 1.5]
17
+ photometric:
18
+ enable: true
19
+ primitives: ['random_brightness', 'random_contrast',
20
+ 'additive_speckle_noise', 'additive_gaussian_noise',
21
+ 'additive_shade', 'motion_blur' ]
22
+ params:
23
+ random_brightness: {brightness: 0.2}
24
+ random_contrast: {contrast: [0.3, 1.5]}
25
+ additive_gaussian_noise: {stddev_range: [0, 10]}
26
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
27
+ additive_shade:
28
+ transparency_range: [-0.5, 0.5]
29
+ kernel_size_range: [100, 150]
30
+ motion_blur: {max_kernel_size: 3}
31
+ random_order: True
32
+ homographic:
33
+ enable: true
34
+ params:
35
+ translation: true
36
+ rotation: true
37
+ scaling: true
38
+ perspective: true
39
+ scaling_amplitude: 0.2
40
+ perspective_amplitude_x: 0.2
41
+ perspective_amplitude_y: 0.2
42
+ patch_ratio: 0.85
43
+ max_angle: 1.57
44
+ allow_artifacts: true
45
+ valid_border_margin: 3
46
+ # The homography adaptation configuration
47
+ homography_adaptation:
48
+ num_iter: 100
49
+ aggregation: 'sum'
50
+ mode: 'ver1'
51
+ valid_border_margin: 3
52
+ min_counts: 30
53
+ homographies:
54
+ translation: true
55
+ rotation: true
56
+ scaling: true
57
+ perspective: true
58
+ scaling_amplitude: 0.2
59
+ perspective_amplitude_x: 0.2
60
+ perspective_amplitude_y: 0.2
61
+ allow_artifacts: true
62
+ patch_ratio: 0.85
63
+ # Evaluation related config
64
+ evaluation:
65
+ repeatability:
66
+ # Initial random seed used to sample homographic augmentation
67
+ seed: 200
68
+ # Parameter used to sample illumination change evaluation set.
69
+ photometric:
70
+ enable: False
71
+ # Parameter used to sample viewpoint change evaluation set.
72
+ homographic:
73
+ enable: True
74
+ num_samples: 2
75
+ params:
76
+ translation: true
77
+ rotation: true
78
+ scaling: true
79
+ perspective: true
80
+ scaling_amplitude: 0.2
81
+ perspective_amplitude_x: 0.2
82
+ perspective_amplitude_y: 0.2
83
+ patch_ratio: 0.85
84
+ max_angle: 1.57
85
+ allow_artifacts: true
86
+ valid_border_margin: 3