Philippe Weinzaepfel committed
Commit 3ef85e9 · 0 Parent(s)
huggingface demo
Browse files
- .gitattributes +27 -0
- LICENSE +446 -0
- NOTICE +46 -0
- README.md +163 -0
- app.py +77 -0
- checkpoints/PUMP-stytrf.pt +3 -0
- checkpoints/PUMP.pt +3 -0
- core/conv_mixer.py +87 -0
- core/cuda_deepm/.gitignore +4 -0
- core/cuda_deepm/__init__.py +9 -0
- core/cuda_deepm/func.cpp +215 -0
- core/cuda_deepm/kernels.cu +578 -0
- core/cuda_deepm/setup.py +24 -0
- core/functional.py +440 -0
- core/losses/__init__.py +8 -0
- core/losses/ap_loss.py +61 -0
- core/losses/ap_loss_sampler.py +131 -0
- core/losses/multiloss.py +57 -0
- core/losses/pixel_ap_loss.py +82 -0
- core/losses/unsupervised_deepmatching_loss.py +146 -0
- core/pixel_desc.py +60 -0
- datasets/__init__.py +9 -0
- datasets/demo_warp/mountains_src.jpg +0 -0
- datasets/demo_warp/mountains_tgt.jpg +0 -0
- datasets/image_set.py +91 -0
- datasets/pair_dataset.py +226 -0
- datasets/pair_loader.py +291 -0
- datasets/sfm120k.py +27 -0
- datasets/transforms.py +540 -0
- datasets/transforms_tools.py +71 -0
- datasets/utils.py +104 -0
- datasets/web_images.py +50 -0
- demo_warping.py +102 -0
- download_training_data.sh +53 -0
- imgs/demo_warp.jpg +0 -0
- imgs/overview.png +0 -0
- imgs/teaser_paper.jpg +0 -0
- imgs/test.png +0 -0
- post_filter.py +235 -0
- requirements.txt +5 -0
- run_ETH3D.py +118 -0
- test_multiscale.py +262 -0
- test_multiscale_recursive.py +24 -0
- test_singlescale.py +284 -0
- test_singlescale_recursive.py +156 -0
- tools/common.py +95 -0
- tools/trainer.py +125 -0
- tools/viz.py +266 -0
- train.py +121 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,446 @@
PUMP
Copyright (c) 2022-present NAVER Corp.
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license

A summary of the CC BY-NC-SA 4.0 license is located here:
https://creativecommons.org/licenses/by-nc-sa/4.0/

---

Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. BY-NC-SA Compatible License means a license listed at
     creativecommons.org/compatiblelicenses, approved by Creative
     Commons as essentially the equivalent of this Public License.

  d. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  e. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  f. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  g. License Elements means the license attributes listed in the name
     of a Creative Commons Public License. The License Elements of this
     Public License are Attribution, NonCommercial, and ShareAlike.

  h. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  i. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  j. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  k. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  l. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  m. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  n. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. Additional offer from the Licensor -- Adapted Material.
               Every recipient of Adapted Material from You
               automatically receives an offer from the Licensor to
               exercise the Licensed Rights in the Adapted Material
               under the conditions of the Adapter's License You apply.

            c. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.
       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

  b. ShareAlike.

     In addition to the conditions in Section 3(a), if You Share
     Adapted Material You produce, the following conditions also apply.

       1. The Adapter's License You apply must be a Creative Commons
          license with the same License Elements, this version or
          later, or a BY-NC-SA Compatible License.

       2. You must include the text of, or the URI or hyperlink to, the
          Adapter's License You apply. You may satisfy this condition
          in any reasonable manner based on the medium, means, and
          context in which You Share Adapted Material.

       3. You may not offer or impose any additional or different terms
          or conditions on, or apply any Effective Technological
          Measures to, Adapted Material that restrict exercise of the
          rights granted under the Adapter's License You apply.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material,
     including for purposes of Section 3(b); and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
NOTICE
ADDED
@@ -0,0 +1,46 @@
PUMP
Copyright (c) 2022-present NAVER Corp.
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license

--------------------------------------------------------------------------------------

This project contains subcomponents with separate copyright notices and license terms.
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.

=====

pytorch/vision
https://github.com/pytorch/vision


BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

=====
README.md
ADDED
@@ -0,0 +1,163 @@
---
title: PUMP
emoji: 📚
colorFrom: yellow
colorTo: red
sdk: gradio
app_file: app.py
pinned: false
---

# PUMP: pyramidal and uniqueness matching priors for unsupervised learning of local features #
![image](imgs/teaser_paper.jpg)

Official repository for the following [paper](https://europe.naverlabs.com/research/publications/pump-pyramidal-and-uniqueness-matching-priors-for-unsupervised-learning-of-local-features/):

```text
@inproceedings{cvpr22_pump,
  author    = {Jerome Revaud and Vincent Leroy and Philippe Weinzaepfel and Boris Chidlovskii},
  title     = {PUMP: pyramidal and uniqueness matching priors for unsupervised learning of local features},
  booktitle = {CVPR},
  year      = {2022},
}
```
![image](imgs/overview.png)

License
-------
Our code is released under the CC BY-NC-SA 4.0 License (see [LICENSE](LICENSE) for more details), available only for non-commercial use.


Requirements
------------
- Python 3.8+ equipped with standard scientific packages and PyTorch / TorchVision:
```
tqdm >= 4
PIL >= 8.1.1
numpy >= 1.19
scipy >= 1.6
torch >= 1.10.0
torchvision >= 0.9.0
matplotlib >= 3.3.4
```
- the CUDA toolkit, to compile custom CUDA kernels (a quick import check is shown after the block):
```bash
cd core/cuda_deepm/
python setup.py install
```
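
After the build, a short import check can confirm that the extension is usable. This is a sketch; it assumes the extension installs under the module name `cuda_deepm`, which is the name that `core/cuda_deepm/__init__.py` imports:
```python
# fails with an ImportError if the CUDA extension did not build or install
import cuda_deepm
print(sorted(k for k in dir(cuda_deepm) if not k.startswith('_')))
# should list the ops bound in core/cuda_deepm/func.cpp, e.g. forward_agg, max_pool3d
```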

Warping Demo
------------

```bash
python demo_warping.py
```

You should see the following result:
![image](imgs/demo_warp.jpg)

Test usage
----------

We provide 4 variations of the pairwise matching code, named `test_xxxscale_yyy.py`:
- xxx: `single`-scale or `multi`-scale.
  Single-scale can cope with a 0.75~1.33x scale difference at most.
  The multi-scale version can also be made rotation-invariant if asked.
- yyy: recursive or not. Recursive is slower but provides denser/better outputs.

For most cases, you want to use `test_multiscale.py`:
```bash
# --resize is important, see below
python test_multiscale.py \
    --img1 path/to/img1 \
    --img2 path/to/img2 \
    --resize 600 \
    --post-filter \
    --output path/to/correspondences.npy
```

It outputs a numpy binary file with the field `file_data['corres']` containing a list of correspondences.
The row format is `[x1, y1, x2, y2, score, scale_rot_code]`.
Use `core.functional.decode_scale_rot(code) --> (scale, angle_in_degrees)` to decode the `scale_rot_code`; a short loading sketch follows.
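
For illustration, here is a minimal loading sketch. It assumes the `.npy` file stores a pickled Python dict (hence `allow_pickle=True` and `.item()`), and calls `decode_scale_rot` per the stated signature; check both against `core/functional.py`:
```python
import numpy as np
from core.functional import decode_scale_rot

file_data = np.load('path/to/correspondences.npy', allow_pickle=True).item()
corres = file_data['corres']  # rows: [x1, y1, x2, y2, score, scale_rot_code]

for x1, y1, x2, y2, score, code in corres[:5]:
    scale, angle = decode_scale_rot(code)
    print(f'({x1:.1f},{y1:.1f}) -> ({x2:.1f},{y2:.1f})  '
          f'score={score:.2f}  scale={scale:.2f}  rot={angle:.0f}deg')
```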


#### Optional parameters:

- **Prior image resize**: `--resize SIZE`

  This is a very important parameter. In general, the bigger, the better (and the slower).
  Be wary that the memory footprint explodes with the image size.
  Here is a table of maximum `--resize` values depending on the image aspect ratio:

  | Aspect ratio | Example img sizes | GPU memory | resize |
  |--------------|-------------------|------------|--------|
  | 4/3          | 800x600, 1024x768 | 16 GB      | 600    |
  | 4/3          | 800x600, 1024x768 | 22 GB      | 680    |
  | 4/3          | 800x600, 1024x768 | 32 GB      | 760    |
  | 1/1          | 1024x1024         | 16 GB      | 540    |
  | 1/1          | 1024x1024         | 22 GB      | 600    |
  | 1/1          | 1024x1024         | 32 GB      | 660    |

  (Formula: `memory_in_bytes = (W1*H1*W2*H2)*1.333*2/16`)
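
  As a quick sanity check of this formula, here is the arithmetic for the first table row (a sketch; the `600x450` sizes assume `--resize 600` applied to two 4/3 images):
  ```python
  W1 = W2 = 600; H1 = H2 = 450  # two 4/3 images after --resize 600
  memory_in_bytes = (W1*H1*W2*H2) * 1.333 * 2 / 16
  print(memory_in_bytes / 1024**3)  # ~11.3 GiB, which indeed fits in the 16 GB budget
  ```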

- **Base descriptor**: `--desc {PUMP, PUMP-stytrf}`

  We provide the `PUMP` descriptor from our paper, as well as `PUMP-stytrf` (with additional style-transfer training).
  Defaults to `PUMP-stytrf`.

- **Scale**: `--max-scale SCALE`

  By default, this value is set to 4, meaning that PUMP is _at least_ invariant to a 4x zoom-in or
  zoom-out. In practically all cases, this is more than enough. You may reduce this value if you know
  it is too much, in order to accelerate computations.

- **Rotation**: `--max-rot DEGREES`

  By default, PUMP is not rotation-invariant. To enforce rotation invariance, you need to specify
  the amount of rotation it can tolerate. The more, the slower. The maximum value is 180.
  If you know that images are not vertically oriented, you can just use 90 degrees.

- **post-filter**: `--post-filter "option1=val1,option2=val2,..."`

  When activated, post-filtering removes spurious correspondences based on their local consistency.
  See `python post_filter.py --help` for details about the possible options.
  It is geometry-agnostic and naturally supports dynamic scenes.
  If you want to output _pixel-dense_ correspondences (a.k.a. _optical flow_), you need to post-process
  the correspondences with `--post-filter densify=True`. See `demo_warping.py` for an example.


#### Visualization of results:
```bash
python -m tools.viz --img1 path/to/img1 --img2 path/to/img2 --corres path/to/correspondences.npy
```

Reproducing results on the ETH-3D dataset
-----------------------------------------

1. Download the ETH-3D dataset from [their website](https://www.eth3d.net/datasets) and extract it in `datasets/eth3d/`

2. Run `python run_ETH3D.py`. You should get results slightly better than those reported in the paper.


Training PUMP from scratch
--------------------------

1. Download the training data with
```bash
bash download_training_data.sh
```

   This consists of web images from [this paper](http://cmp.felk.cvut.cz/revisitop/) for the self-supervised loss (as in [R2D2](https://github.com/naver/r2d2))
   and image pairs from the [SfM120k dataset](http://cmp.felk.cvut.cz/cnnimageretrieval/) with automatically
   extracted pixel correspondences. Note that correspondences are *not* used in the loss, since the loss is
   unsupervised. They are only necessary so that random cropping produces pairs of crops that are at least partially aligned.
   Therefore, correspondences do not need to be 100% correct or even pixel-precise.

2. Run `python train.py --save-path <output_dir>/`

   Note that the training code is quite rudimentary (it only supports `nn.DataParallel`,
   with no support for `DataDistributed` at the moment and no validation phase either).

3. Move and rename your final checkpoint to `checkpoints/NAME.pt` and test it with
```bash
python test_multiscale.py ... --desc NAME
```
app.py
ADDED
@@ -0,0 +1,77 @@
import gradio as gr
import sys, os
import torch
import matplotlib.pylab as plt

def pump_matching(img1, img2, trained_with_st=False, scale=300, max_scale=1, max_rot=0, use_gpu=False):
    # single-scale matching suffices when neither extra scales nor rotations are requested
    use_singlescale = max_scale == 1 and max_rot == 0
    if use_singlescale:
        from test_singlescale import Main, arg_parser
    else:
        from test_multiscale import Main, arg_parser
    parser = arg_parser()

    args_list = ['--img1', 'dummy', '--img2', 'dummy', '--post-filter',
                 '--desc', 'PUMP-stytrf' if trained_with_st else 'PUMP', '--resize', str(scale)]
    if not use_gpu:
        args_list += ['--device', 'cpu']
    if not use_singlescale:
        args_list += ['--max-scale', str(max_scale), '--max-rot', str(max_rot)]

    args = parser.parse_args(args_list)

    corres = Main().run_from_args_with_images(img1, img2, args)

    fig1 = plt.figure(1)
    plt.imshow(img1)
    ax1 = plt.gca()
    ax1.axis('off')
    plt.tight_layout()

    fig2 = plt.figure(2)
    plt.imshow(img2)
    ax2 = plt.gca()
    ax2.axis('off')
    plt.tight_layout()

    from tools.viz import plot_grid
    if corres.shape[-1] > 4:
        corres = corres[corres[:, 4] > 0, :]  # select non-null correspondences
    if corres.shape[0] > 0: plot_grid(corres, ax1, ax2, marker='+')

    img1 = None
    img2 = None

    return fig1, fig2

has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0

title = "PUMP local descriptor demo"
description = "This is a visualization demo for the PUMP local descriptors presented in our CVPR 2022 paper <b><a href='https://europe.naverlabs.com/research/publications/pump-pyramidal-and-uniqueness-matching-priors-for-unsupervised-learning-of-local-features/' target='_blank'>PUMP: Pyramidal and Uniqueness Matching Priors for Unsupervised Learning of Local Features</a></b>.</p><p><b>WARNING:</b> this demo runs on CPU with downscaled images and without multi-scale or multi-rotation testing, due to limited memory and computational resources. Please check out our <a href='https://github.com/naver/pump' target='_blank'>original github repo</a> for these features.</p>"

article = "<p style='text-align: center'><a href='https://github.com/naver/pump' target='_blank'>Original Github Repo</a></p>"

iface = gr.Interface(
    fn=pump_matching,
    inputs=[
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
        gr.inputs.Checkbox(default=False, label="Use the model trained with style transfer"),
        # optional controls, disabled in this hosted demo:
        #gr.inputs.Slider(minimum=300, maximum=600, default=400, step=1, label="Original test scale"),
        #gr.inputs.Slider(minimum=1, maximum=4, default=1, step=0.1, label="Multi Scale Testing - maximum scale (makes it slower)"),
        #gr.inputs.Slider(minimum=0, maximum=180, default=0, step=45, label="Multi Rotation Testing - max rot (makes it slower)"),
        #+ ([gr.inputs.Checkbox(default=True, label='Use GPU instead of CPU')] if has_cuda else []),
    ],
    outputs=[
        gr.outputs.Image(type="plot", label="Matches in the first image"),
        gr.outputs.Image(type="plot", label="Matches in the second image"),
    ],
    title=title,
    theme='peach',
    description=description,
    article=article,
    examples=[
        ['datasets/demo_warp/mountains_src.jpg', 'datasets/demo_warp/mountains_tgt.jpg', False],  #,400,1,0]+([True] if has_cuda else []),
    ]
)
iface.launch(enable_queue=True)
checkpoints/PUMP-stytrf.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e78a9bbbd8a6c9823265adf41b4a330f87fa58fb07832d6d56c6ae94769fd27d
size 13976029
checkpoints/PUMP.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a58cf5a1a4699e087c269ec9054c35637cd056fc68a37f1ee96da6b53e0804f
size 13976029
core/conv_mixer.py
ADDED
@@ -0,0 +1,87 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb

import torch
import torch.nn as nn
import torch.nn.functional as F

""" From the ICLR22 paper: Patches are all you need
    https://openreview.net/pdf?id=TVHS5Y4dNvM
"""

class Residual(nn.Module):
    def __init__(self, fn, stride=1):
        super().__init__()
        self.fn = fn
        self.stride = stride

    def forward(self, x):
        # strided residual: subsample both the identity and the branch output
        s = slice(None, None, self.stride)
        return x[:, :, s, s] + self.fn(x)[:, :, s, s]


class ConvMixer (nn.Sequential):
    """ Modified ConvMixer with convolutional layers at the bottom.

    From the ICLR22 paper: Patches are all you need, https://openreview.net/pdf?id=TVHS5Y4dNvM
    """
    def __init__(self, output_dim, hidden_dim,
                       depth=None, kernel_size=5, patch_size=8, group_size=1,
                       preconv=1, faster=True, relu=nn.ReLU):

        assert kernel_size % 2 == 1, 'kernel_size must be odd'
        output_step = 1 + faster
        assert patch_size % output_step == 0, f'patch_size must be multiple of {output_step}'
        self.patch_size = patch_size

        hidden_dims = [hidden_dim//4]*preconv + [hidden_dim]*(depth+1)
        ops = [
            nn.Conv2d(3, hidden_dims[0], kernel_size=5, padding=2),
            relu(),
            nn.BatchNorm2d(hidden_dims[0])]

        for _ in range(1, preconv):
            ops += [
                nn.Conv2d(hidden_dims.pop(0), hidden_dims[0], kernel_size=3, padding=1),
                relu(),
                nn.BatchNorm2d(hidden_dims[0])]

        # patch embedding: non-overlapping patch_size x patch_size convolution
        ops += [
            nn.Conv2d(hidden_dims.pop(0), hidden_dims[0], kernel_size=patch_size, stride=patch_size),
            relu(),
            nn.BatchNorm2d(hidden_dims[0])]

        for idim, odim in zip(hidden_dims[0:], hidden_dims[1:]):
            ops += [Residual(nn.Sequential(
                        nn.Conv2d(idim, idim, kernel_size, groups=max(1, idim//group_size), padding=kernel_size//2),
                        relu(),
                        nn.BatchNorm2d(idim)
                    )),
                    nn.Conv2d(idim, odim, kernel_size=1),
                    relu(),
                    nn.BatchNorm2d(odim)]
        # expand back to (nearly) full resolution via pixel-shuffle + bilinear upsampling
        ops += [
            nn.Conv2d(odim, output_dim*(patch_size//output_step)**2, kernel_size=1),
            nn.PixelShuffle( patch_size//output_step ),
            nn.Upsample(scale_factor=output_step, mode='bilinear', align_corners=False)]

        super().__init__(*ops)

    def forward(self, img):
        assert img.ndim == 4
        B, C, H, W = img.shape
        desc = super().forward(img)
        return F.normalize(desc, dim=-3)


if __name__ == '__main__':
    # note: the original demo instantiated an undefined `ConvMixer3`; `ConvMixer` is the class above
    net = ConvMixer(128, 512, depth=7, patch_size=4, kernel_size=9)
    print(net)

    img = torch.rand(2, 3, 256, 256)
    print('input.shape =', img.shape)
    desc = net(img)
    print('desc.shape =', desc.shape)
core/cuda_deepm/.gitignore
ADDED
@@ -0,0 +1,4 @@
*.so
_ext*
__pycache__
build
core/cuda_deepm/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

# run `python setup.py install`
import cuda_deepm as _kernels

__all__ = {k: v for k, v in vars(_kernels).items() if k[0] != '_'}
globals().update(__all__)
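The re-export above lets callers pull the compiled ops directly from `core.cuda_deepm`. As an illustration, a hypothetical call to `max_pool3d` (shapes follow the checks in `func.cpp` below; the op is CUDA-only, and the example tensor is made up):

```python
import torch
from core.cuda_deepm import max_pool3d  # re-exported from the compiled extension

t = torch.rand(4, 16, 128, 128, device='cuda')  # 4-D input, BxCxHxW
maxima, indices = max_pool3d(t, 4, 4)           # kernel_size=4, stride=4
print(maxima.shape)  # torch.Size([4, 32, 32]) since OH = OW = 1 + (128-4)//4
```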
core/cuda_deepm/func.cpp
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright 2022-present NAVER Corp.
|
2 |
+
// CC BY-NC-SA 4.0
|
3 |
+
// Available only for non-commercial use
|
4 |
+
|
5 |
+
#include <torch/extension.h>
|
6 |
+
using namespace torch::indexing; // Slice
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
10 |
+
#define MAX(x, y) ((x) < (y) ? (y) : (x))
|
11 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
12 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
13 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
14 |
+
|
15 |
+
inline Slice sl(bool x) {
|
16 |
+
if (x)
|
17 |
+
return Slice(0, -1);
|
18 |
+
else
|
19 |
+
return Slice(1, None);
|
20 |
+
}
|
21 |
+
|
22 |
+
torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
|
23 |
+
const at::optional<at::Tensor> weights, torch::Tensor upper );
|
24 |
+
|
25 |
+
std::vector<torch::Tensor> forward_agg( int level, float norm, const torch::Tensor lower,
|
26 |
+
const at::optional<at::Tensor> weights = at::nullopt ) {
|
27 |
+
TORCH_CHECK(level >= 1, "level must be >= 1");
|
28 |
+
TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
|
29 |
+
const auto LH1 = lower.size(0);
|
30 |
+
const auto LW1 = lower.size(1);
|
31 |
+
const auto LH2 = lower.size(2);
|
32 |
+
const auto LW2 = lower.size(3);
|
33 |
+
if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
|
34 |
+
const auto UH1 = (level == 1) ? LH1+1 : LH1;
|
35 |
+
const auto UW1 = (level == 1) ? LW1+1 : LW1;
|
36 |
+
|
37 |
+
TORCH_CHECK(lower.is_cuda())
|
38 |
+
auto upper = torch::zeros({UH1, UW1, LH2, LW2}, lower.options());
|
39 |
+
torch::Tensor new_weights = forward_agg_cuda( level, norm, lower, weights, upper );
|
40 |
+
return {upper, new_weights};
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
|
45 |
+
const at::optional<at::Tensor> weights, torch::Tensor upper );
|
46 |
+
|
47 |
+
std::vector<torch::Tensor> forward_pool_agg( int level, float norm, const torch::Tensor lower,
|
48 |
+
const at::optional<at::Tensor> weights = at::nullopt ) {
|
49 |
+
TORCH_CHECK(level >= 1, "level must be >= 1");
|
50 |
+
TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
|
51 |
+
const auto LH1 = lower.size(0);
|
52 |
+
const auto LW1 = lower.size(1);
|
53 |
+
const auto LH2 = lower.size(2);
|
54 |
+
const auto LW2 = lower.size(3);
|
55 |
+
if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
|
56 |
+
const auto UH1 = (level == 1) ? LH1+1 : LH1;
|
57 |
+
const auto UW1 = (level == 1) ? LW1+1 : LW1;
|
58 |
+
|
59 |
+
TORCH_CHECK(lower.is_cuda())
|
60 |
+
auto upper = torch::zeros({UH1, UW1, 1+(LH2-1)/2, 1+(LW2-1)/2}, lower.options());
|
61 |
+
torch::Tensor new_weights = forward_pool_agg_cuda( level, norm, lower, weights, upper );
|
62 |
+
return {upper, new_weights};
|
63 |
+
}
|
64 |
+
|
65 |
+
// forward declaration
|
66 |
+
void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders );
|
67 |
+
|
68 |
+
void backward_agg_unpool( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders = true ) {
|
69 |
+
TORCH_CHECK(level >= 1, "level must be >= 1");
|
70 |
+
TORCH_CHECK( upper.dim() == 4 && lower.dim() == 4, "inputs should be 4-dimensional" );
|
71 |
+
|
72 |
+
TORCH_CHECK(upper.is_cuda() && lower.is_cuda())
|
73 |
+
backward_agg_unpool_cuda(level, upper, lower, exclude_borders);
|
74 |
+
}
|
75 |
+
|
76 |
+
|
77 |
+
void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
|
78 |
+
torch::Tensor maxima, torch::Tensor indices );
|
79 |
+
|
80 |
+
std::vector<torch::Tensor> max_pool3d( const torch::Tensor tensor, const int kernel_size, const int stride ) {
|
81 |
+
TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
|
82 |
+
TORCH_CHECK( 1 <= kernel_size, "bad kernel size %d", kernel_size );
|
83 |
+
TORCH_CHECK( 1 <= stride, "bad stride %d", stride );
|
84 |
+
const int IB = tensor.size(0);
|
85 |
+
const int IH = tensor.size(2); // input height
|
86 |
+
const int IW = tensor.size(3); // input width
|
87 |
+
|
88 |
+
// output size
|
89 |
+
const int OH = 1 + (IH - kernel_size) / stride;
|
90 |
+
const int OW = 1 + (IW - kernel_size) / stride;
|
91 |
+
|
92 |
+
torch::Tensor maxima = torch::empty({IB, OH, OW}, tensor.options());
|
93 |
+
torch::Tensor indices = torch::empty({IB, OH, OW}, tensor.options().dtype(torch::kInt64));
|
94 |
+
|
95 |
+
if (tensor.is_cuda())
|
96 |
+
max_pool3d_cuda( tensor, kernel_size, stride, maxima, indices );
|
97 |
+
else
|
98 |
+
TORCH_CHECK(false, "CPU max_pool3d not implemented yet");
|
99 |
+
return {maxima, indices};
|
100 |
+
}
|
101 |
+
|
102 |
+
static inline float ptdot( const float* m, float x, float y ) {
|
103 |
+
return x*m[0] + y*m[1] + m[2];
|
104 |
+
}
|
105 |
+
|
106 |
+
static inline float pow2(float v) {
    return v*v;
}


void merge_corres_cpu( const torch::Tensor corres, int offset, const torch::Tensor _inv_rot,
                       float dmax, torch::Tensor all_corres, const int all_step ) {
    const int H = corres.size(0);
    const int W = corres.size(1);
    const float tol = 2*2; // squared tolerance factor
    dmax *= dmax; // squared

    TORCH_CHECK( _inv_rot.is_contiguous() );
    const float* inv_rot = _inv_rot.data_ptr<float>();

    auto corres_a = corres.accessor<float,3>();
    auto all_corres_a = all_corres.accessor<float,3>();

    // for each bin of the final histograms, we get the nearest-neighbour bin in corres0 and corres1
    for (int j = 0; j < all_corres.size(0); j++)
      for (int i = 0; i < all_corres.size(1); i++) {
        auto all_cor = all_corres_a[j][i];

        // center of the bin in the reference frame
        float x = i*all_step + all_step/2;
        float y = j*all_step + all_step/2;

        // center of the bin on the rescaled+rotated image
        float xr = ptdot( inv_rot + 0, x, y );
        float yr = ptdot( inv_rot + 3, x, y );

        // iterate on the nearby bins
        int xb = (int)(0.5 + xr/4); // rescaled+rotated desc always has step 4
        int yb = (int)(0.5 + yr/4);

        // first pass: squared distance to the nearest correspondence
        // (upper bounds are H-1 and W-1: corres_a[v][u] is accessed directly,
        //  matching the exclusive bound check of the CUDA version)
        float best = dmax;
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d < best ) best = d;
          }

        // second pass: merge all spatially close correspondences
        for (int v = MAX(0,yb-1); v <= MIN(H-1,yb+1); v++)
          for (int u = MAX(0,xb-1); u <= MIN(W-1,xb+1); u++) {
            auto cor = corres_a[v][u];
            float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
            if( d <= tol*best ) { // spatially close
                // merge correspondence if score is better than actual
                if( cor[4] > all_cor[4] )
                    for (int k = 0; k < all_corres.size(2); k++)
                        all_cor[k] = cor[k];
            }
          }
    }
}

void merge_corres_cuda( const torch::Tensor corres, int offset, const torch::Tensor inv_rot,
                        float dmax, torch::Tensor all_corres, const int all_step );

void merge_corres( const torch::Tensor corres, int offset, const torch::Tensor rot,
                   torch::Tensor all_corres, const int all_step ) {
    TORCH_CHECK( corres.dim() == 3 && corres.size(2) == 6, "corres.shape should be (H,W,6)" );
    TORCH_CHECK( all_corres.dim() == 3 && all_corres.size(2) == 6, "all_corres.shape should be (H,W,6)" );

    float dmax = 8 * torch::sqrt(torch::det(rot)).item<float>();
    torch::Tensor inv_rot = torch::inverse(rot).contiguous();

    if (all_corres.is_cuda())
        merge_corres_cuda( corres, offset, inv_rot, dmax, all_corres, all_step );
    else
        merge_corres_cpu( corres, offset, inv_rot, dmax, all_corres, all_step );
}


void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
                                    const float radius, const float alpha);

void mask_correlations_radial( torch::Tensor corr, const torch::Tensor targets,
                               const float radius, const float alpha) {
    // radius: protected area in pixels around each target center
    // alpha: in [0,1]. If alpha = 0: no effect. If alpha = 1: full effect.
    TORCH_CHECK( corr.dim() == 4 );
    TORCH_CHECK( targets.dim() == 3 );
    TORCH_CHECK( targets.size(0) == corr.size(0) && targets.size(1) == corr.size(1) && targets.size(2) == 2,
                 "correlations and targets should have the same shape[:2]" );

    if (corr.is_cuda())
        mask_correlations_radial_cuda( corr, targets, radius, alpha );
    else
        TORCH_CHECK(false, "TODO");
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_agg", &forward_agg, "forward aggregation (CUDA)");
    m.def("forward_pool_agg", &forward_pool_agg, "forward pooling and aggregation (CUDA)");
    m.def("backward_agg_unpool", &backward_agg_unpool, "backward sparse-conv and max-unpooling (C++ & CUDA)");
    m.def("max_pool3d", &max_pool3d, "max_pool3d that can handle big inputs (CUDA)");
    m.def("merge_corres_one_side", &merge_corres, "merge correspondences on CPU or GPU" );
    m.def("mask_correlations_radial", &mask_correlations_radial, "mask correlations radially (CUDA)" );
}
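For reference, a minimal Python usage sketch for the bindings exported above; this is an illustration, not a file from this commit. It assumes the extension has been built and imported as `cuda_deepm`, and uses the (H,W,6) layout of (x1,y1,x2,y2,score,code) rows that `merge_corres` checks and that core/functional.py assembles.

# hypothetical usage of the compiled extension; `cuda_deepm` must have been built first
import torch
import cuda_deepm

H, W = 32, 48
corres     = torch.zeros(H, W, 6, device='cuda')  # one (x1,y1,x2,y2,score,code) row per bin
all_corres = torch.zeros(H, W, 6, device='cuda')  # running accumulator across scales/rotations
rot = torch.eye(3, device='cuda')                 # reference --> rotated transform
cuda_deepm.merge_corres_one_side(corres, 0, rot, all_corres, 4)  # offset=0 selects the (x1,y1) side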
core/cuda_deepm/kernels.cu
ADDED
@@ -0,0 +1,578 @@
// Copyright 2022-present NAVER Corp.
// CC BY-NC-SA 4.0
// Available only for non-commercial use

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define inf std::numeric_limits<float>::infinity()

#define CHECK_CUDA(tensor) {\
    TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
    TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}


#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 600
#define atomicMax_block atomicMax
#endif


template <typename scalar_t>
__global__ void forward_agg_cuda_kernel(
        const int LH1, const int LW1, const int LH2, const int LW2,
        const int gap_left, const int gap_right, float norm,
        const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> lower,
        torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> upper,
        const float* weights, float* new_weights ) {

    const auto UH1 = LH1 + bool(!gap_left); // level 0 is smaller than other levels
    const auto UW1 = LW1 + bool(!gap_left);
    const auto UH2 = LH2;
    const auto UW2 = LW2;

    // recover the 4D output position from the flat thread index
    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int uw2 = idx % UW2; idx /= UW2;
    const int uh2 = idx % UH2; idx /= UH2;
    const int uw1 = idx % UW1; idx /= UW1;
    const int uh1 = idx;
    if (uh1 >= UH1) return;

    // then, add the 4 children
    float sumw = 0, nrm = 0, res = 0;
    for (int i = 0; i < 4; i++) {
        const int v = i/2, u = i%2;
        // source pixel
        const int lh1 = uh1 + (1-v) * gap_left - v * gap_right;
        if (lh1 < 0 || lh1 >= LH1) continue;
        const int lw1 = uw1 + (1-u) * gap_left - u * gap_right;
        if (lw1 < 0 || lw1 >= LW1) continue;

        // load weight even if (lh2,lw2) are invalid
        const float weight = weights ? weights[lh1*LW1 + lw1] : 1;
        sumw += weight;

        const int lh2 = uh2 + 1 - 2*v;
        if (lh2 < 0 || lh2 >= LH2) continue;
        const int lw2 = uw2 + 1 - 2*u;
        if (lw2 < 0 || lw2 >= LW2) continue;

        res += weight * lower[lh1][lw1][lh2][lw2];
        nrm += weight;
    }

    // normalize output
    nrm = sumw * (nrm < sumw ? powf(nrm/sumw, norm) : 1);
    upper[uh1][uw1][uh2][uw2] = (nrm ? res / nrm : 0);
    if (uh2 == 1 && uw2 == 1)
        new_weights[uh1*UW1 + uw1] = sumw;
}

torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
                                const at::optional<at::Tensor> weights, torch::Tensor upper ) {
    CHECK_CUDA(lower);
    CHECK_CUDA(upper);
    if (weights) CHECK_CUDA(weights.value());

    const auto UH1 = upper.size(0);
    const auto UW1 = upper.size(1);
    const auto UH2 = upper.size(2);
    const auto UW2 = upper.size(3);
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    TORCH_CHECK( UH1 == LH1 + int(level==1) && UW1 == LW1 + int(level==1), "inconsistent lower and upper shapes" );

    const int gap_left = (level >= 2) ? 1 << (level-2) : 0; // 0, 1, 2, 4, ...
    const int gap_right= 1 << MAX(0, level-2);              // 1, 1, 2, 4, ...

    const int MAX_THREADS = 512; // faster than 1024 (higher SM occupancy)
    const int THREADS_PER_BLOCK = MAX_THREADS;
    const int N_BLOCKS = (UH1*UW1*UH2*UW2 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;

    torch::Tensor new_weights = torch::zeros({UH1, UW1}, upper.options().dtype(torch::kFloat32));

    // one thread per output cell
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(lower.type(), "forward_agg_cuda", ([&] {
        forward_agg_cuda_kernel<<<N_BLOCKS, THREADS_PER_BLOCK>>>(
            LH1, LW1, LH2, LW2,
            gap_left, gap_right, norm,
            lower.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
            upper.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
            weights ? weights->data_ptr<float>() : nullptr, new_weights.data_ptr<float>() );
    }));
    return new_weights;
}

template <typename scalar_t>
__global__ void forward_pool_agg_cuda_kernel(
        const int LH1, const int LW1, const int LH2, const int LW2,
        const int gap_left, const int gap_right, float norm,
        const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> lower,
        torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> upper,
        const float* weights, float* new_weights ) {

    const auto UH1 = LH1 + bool(!gap_left); // level 0 is smaller than other levels
    const auto UW1 = LW1 + bool(!gap_left);
    const auto UH2 = (LH2-1)/2 + 1;
    const auto UW2 = (LW2-1)/2 + 1;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int uw2 = idx % UW2; idx /= UW2;
    const int uh2 = idx % UH2; idx /= UH2;
    const int uw1 = idx % UW1; idx /= UW1;
    const int uh1 = idx;
    if (uh1 >= UH1) return;

    // then, add the 4 children
    float sumw = 0, nrm = 0, res = 0;
    for (int i = 0; i < 4; i++) {
        const int v = i/2, u = i%2;
        // source pixel
        const int lh1 = uh1 + (1-v) * gap_left - v * gap_right;
        if (lh1 < 0 || lh1 >= LH1) continue;
        const int lw1 = uw1 + (1-u) * gap_left - u * gap_right;
        if (lw1 < 0 || lw1 >= LW1) continue;

        // load weight even if (lh2,lw2) are invalid
        const float weight = weights ? weights[lh1*LW1 + lw1] : 1;
        sumw += weight;

        // 3x3 max-pooling on-the-fly in the lower level
        const int lh2_ = 2*(uh2 + 1 - 2*v); // position in lower
        const int lw2_ = 2*(uw2 + 1 - 2*u);
        float lower_max = -inf;
        #pragma unroll
        for (int j = -1; j <= 1; j++) {
            const int lh2 = lh2_ + j;
            if (lh2 < 0 || lh2 >= LH2) continue;
            #pragma unroll
            for (int i = -1; i <= 1; i++) {
                const int lw2 = lw2_ + i;
                if (lw2 < 0 || lw2 >= LW2) continue;
                float l = lower[lh1][lw1][lh2][lw2];
                lower_max = MAX(lower_max, l);
            }}
        if (lower_max == -inf) continue;

        res += weight * lower_max;
        nrm += weight;
    }

    // normalize output
    nrm = sumw * (nrm < sumw ? powf(nrm/sumw, norm) : 1);
    upper[uh1][uw1][uh2][uw2] = (nrm ? res / nrm : 0);
    if (uh2 == 1 && uw2 == 1)
        new_weights[uh1*UW1 + uw1] = sumw;
}

torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
                                     const at::optional<at::Tensor> weights, torch::Tensor upper ) {
    CHECK_CUDA(lower);
    CHECK_CUDA(upper);
    if (weights) CHECK_CUDA(weights.value());

    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    const auto UH1 = upper.size(0);
    const auto UW1 = upper.size(1);
    const auto UH2 = upper.size(2);
    const auto UW2 = upper.size(3);
    TORCH_CHECK( UH1 == LH1 + int(level==1) && UW1 == LW1 + int(level==1), "inconsistent lower and upper shapes" );
    TORCH_CHECK( UH2 == (LH2-1)/2+1 && UW2 == (LW2-1)/2+1, "lower level should be twice as big" );

    const int gap_left = (level >= 2) ? 1 << (level-2) : 0; // 0, 1, 2, 4, ...
    const int gap_right= 1 << MAX(0, level-2);              // 1, 1, 2, 4, ...

    const int MAX_THREADS = 512; // faster than 1024 (higher SM occupancy)
    const int THREADS_PER_BLOCK = MAX_THREADS;
    const int N_BLOCKS = (UH1*UW1*UH2*UW2 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;

    torch::Tensor new_weights = torch::zeros({UH1, UW1}, upper.options().dtype(torch::kFloat));

    // one thread per output cell
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(lower.type(), "forward_pool_agg_cuda", ([&] {
        forward_pool_agg_cuda_kernel<<<N_BLOCKS, THREADS_PER_BLOCK>>>(
            LH1, LW1, LH2, LW2,
            gap_left, gap_right, norm,
            lower.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
            upper.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
            weights ? weights->data_ptr<float>() : nullptr, new_weights.data_ptr<float>() );
    }));
    return new_weights;
}

__device__ inline int in(int lower, int var, int upper) {
    return lower <= var && var < upper;
}
__device__ inline int sl(bool b) {
    return b ? 1 : -1;
}

__device__ short atomicMaxShort(short* address, short val) {
    unsigned int *base_address = (unsigned int *)((size_t)address & ~3); // multiple of 4

    unsigned int order_from[] = {0x0010, 0x0032}; // either bytes[0:2] or bytes[2:4]
    unsigned int from = order_from[((size_t)address & 3) / 2];

    unsigned int order_back[] = {0x3254, 0x5410}; // right-to-left
    unsigned int back = order_back[((size_t)address & 3) / 2];
    unsigned int old, assumed, max_, new_;

    old = *base_address;
    do {
        assumed = old;
        max_ = max(val, (short)__byte_perm(old, 0, from)); // extract word
        new_ = __byte_perm(old, max_, back);               // replace word
        old = atomicCAS(base_address, assumed, new_);
    } while (assumed != old);
    return old;
}

template <typename scalar_t>
__device__ inline void TplAtomicMax_block( scalar_t* before, scalar_t after ) { assert(!"atomicMax not implemented for this dtype"); }
template <>
__device__ inline void TplAtomicMax_block( at::Half* before, at::Half after ) { atomicMaxShort( (int16_t*)before, *(int16_t*)&after ); }
template <>
__device__ inline void TplAtomicMax_block( float* before, float after ) { atomicMax_block( (int32_t*)before, *(int32_t*)&after ); }

template <typename scalar_t>
__global__ void backward_agg_unpool_cuda_kernel(
        const int UH1, const int UW1,
        const int UH2, const int UW2,
        const int LH2, const int LW2,
        const int gap_left, const int gap_right,
        const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> upper,
        torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> lower ) {

    /* Each block takes care of a single layer, i.e. lower[:,:,0::2,0::2].
       The first thread allocates some global memory (device-side new) and frees it at the end.
    */
    const int lh1 = blockIdx.y;
    const int lw1 = blockIdx.x;
    const int UHW2 = UH2 * UW2; // upper layer size

    __shared__ float* _shared_addr;
    if (threadIdx.x == 0)
        do{ _shared_addr = new float [2*UHW2]; } // for each upper place, (best, bestp) interleaved
        while(!_shared_addr); // waiting for memory to be available...
    __syncthreads();

    float * layer_best  = _shared_addr;           // values at even offsets: layer_best[2*idx]
    int   * layer_bestp = (int*)(_shared_addr+1); // indices at odd offsets: layer_bestp[2*idx]
    assert( layer_best );

    /* First pass: we recover the position and values of all local maxima in the layer */
    for (int idx = threadIdx.x; idx < UHW2; idx += blockDim.x) {
        const int ux = idx % UW2;
        const int uy = idx / UW2;
        const int lx = 2*ux; // lower pos from upper pos
        const int ly = 2*uy;

        // argmax over the local 3x3 neighbourhood
        float best = -inf;
        int bestp = 0;
        #pragma unroll
        for (int j_= -1; j_<= 1; j_++) {
            const int j = ly + j_;
            if (j < 0 || j >= LH2) continue;
            #pragma unroll
            for (int i_= -1; i_<= 1; i_++) {
                const int i = lx + i_;
                if (i < 0 || i >= LW2) continue;
                float cur = lower[lh1][lw1][j][i];
                if (cur > best) { best = cur; bestp = j*LW2+i; }
            }}
        layer_best[2*idx] = best;
        layer_bestp[2*idx] = bestp;
    }

    __syncthreads();

    /* Second pass: we update the local maxima according to the upper layer */
    for (int idx = threadIdx.x; idx < UHW2; idx += blockDim.x) {
        const int ux = idx % UW2;
        const int uy = idx / UW2;

        // max-pool the additional value from the upper layer
        scalar_t add = 0;
        for (int v = -gap_left; v <= gap_right; v += gap_right+gap_left) {
            for (int u = -gap_left; u <= gap_right; u += gap_right+gap_left) {
                const int uh1 = lh1 + v, uw1 = lw1 + u;
                const int uh2 = uy+sl(v>0), uw2 = ux+sl(u>0);
                if (in(0, uh1, UH1) && in(0, uw1, UW1) && in(0, uh2, UH2) && in(0, uw2, UW2))
                    add = MAX(add, upper[uh1][uw1][uh2][uw2]);
            }}

        // grab local maxima
        float best = layer_best[2*idx];
        int bestp = layer_bestp[2*idx];
        const int lx = bestp % LW2;
        const int ly = bestp / LW2;

        scalar_t* before = & lower[lh1][lw1][ly][lx];
        scalar_t after = best + add;
        TplAtomicMax_block<scalar_t>( before, after );
    }

    __syncthreads();

    if (threadIdx.x == 0)
        delete [] _shared_addr; // array delete, matching new[]
}

void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders ) {
    CHECK_CUDA(lower);
    CHECK_CUDA(upper);

    const auto UH1 = upper.size(0);
    const auto UW1 = upper.size(1);
    const auto UH2 = upper.size(2);
    const auto UW2 = upper.size(3);
    const auto LH1 = lower.size(0);
    const auto LW1 = lower.size(1);
    const auto LH2 = lower.size(2);
    const auto LW2 = lower.size(3);
    TORCH_CHECK( UH1 == LH1 + int(level==1) && UW1 == LW1 + int(level==1), "inconsistent lower and upper shapes" );
    const int xb = exclude_borders; // local_argmax cannot reach the bottom and right borders

    const int gap_left = (level >= 2) ? 1 << (level-2) : 0; // 0, 1, 2, 4, ...
    const int gap_right= 1 << MAX(0, level-2);              // 1, 1, 2, 4, ...

    const int64_t MAX_THREADS = 1024;
    const int64_t THREADS_PER_LAYER = MIN(UH2*UW2, MAX_THREADS);

    // one block for each layer, one thread per local-max
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(upper.type(), "backward_agg_unpool_cuda", ([&] {
        backward_agg_unpool_cuda_kernel<<<dim3(LW1,LH1), THREADS_PER_LAYER>>>(
            UH1, UW1, UH2, UW2, LH2-xb, LW2-xb,
            gap_left, gap_right,
            upper.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
            lower.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>());
    }));
    CHECK_KERNEL();
}

template <typename scalar_t>
__global__ void max_pool3d_cuda_kernel(
        const int BS, const int NC, const int IH, const int IW, const int OH, const int OW,
        const int ks, const int stride,
        const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> tensor,
        torch::PackedTensorAccessor64<scalar_t,3,torch::RestrictPtrTraits> maxima,
        torch::PackedTensorAccessor64<int64_t, 3,torch::RestrictPtrTraits> indices ) {

    // each thread takes care of one output
    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int x = idx % OW; idx /= OW;
    const int y = idx % OH; idx /= OH;
    const int b = idx;
    if (b >= BS) return;

    float best = -inf;
    int64_t best_pos = 0;
    for (int64_t c = 0; c < NC; c++) {
        for (int j = stride*y; j < stride*y+ks; j++) {
            for (int i = stride*x; i < stride*x+ks; i++) {
                float cur = tensor[b][c][j][i];
                if (cur > best) {best = cur; best_pos = (c*IH + j)*IW + i; }
            }}}

    maxima [b][y][x] = best;
    indices[b][y][x] = best_pos;
}

void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
                      torch::Tensor maxima, torch::Tensor indices ) {
    CHECK_CUDA(tensor);
    TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
    const int BS = tensor.size(0);
    const int NC = tensor.size(1);
    const int IH = tensor.size(2); // input height
    const int IW = tensor.size(3); // input width

    // output size
    TORCH_CHECK( maxima.sizes() == indices.sizes(), "maxima and indices should have the same shape" );
    TORCH_CHECK( BS == maxima.size(0), "bad batch size" );
    const int OH = maxima.size(1);
    const int OW = maxima.size(2);

    const int64_t THREADS_PER_LAYER = 512;
    const int64_t N_BLOCKS = (BS*OH*OW + THREADS_PER_LAYER-1) / THREADS_PER_LAYER;

    // one thread per output cell
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.type(), "max_pool3d_cuda", ([&] {
        max_pool3d_cuda_kernel<<<N_BLOCKS, THREADS_PER_LAYER>>>(
            BS, NC, IH, IW, OH, OW, kernel_size, stride,
            tensor. packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
            maxima. packed_accessor64<scalar_t,3,torch::RestrictPtrTraits>(),
            indices.packed_accessor64<int64_t,3,torch::RestrictPtrTraits>());
    }));
}


__device__ inline float ptdot( const float* m, float x, float y ) {
    return x*m[0] + y*m[1] + m[2];
}

__device__ inline float sqr(float v) {
    return v*v;
}


__global__ void merge_corres_cuda_kernel(
        const int OH, const int OW, const int OZ, const int IH, const int IW,
        const float dmax2, int offset, const float* inv_rot, const int all_step,
        const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> corres_a,
        torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> all_corres_a ) {

    // each thread takes care of one output
    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int i = idx % OW; idx /= OW;
    const int j = idx;
    if (j >= OH) return;

    const float tol2 = 2*2; // squared
    auto all_cor = all_corres_a[j][i];

    // center of the bin in the reference frame
    float x = i*all_step + all_step/2;
    float y = j*all_step + all_step/2;

    // center of the bin on the rescaled+rotated image
    float xr = ptdot( inv_rot + 0, x, y );
    float yr = ptdot( inv_rot + 3, x, y );

    // iterate on the nearby bins
    int xb = (int)(0.5+ xr/4); // rescaled+rotated desc always has step 4
    int yb = (int)(0.5+ yr/4);

    // first pass: squared distance to the nearest correspondence
    float best = dmax2;
    #pragma unroll
    for (int _v = -1; _v <= 1; _v++) {
      #pragma unroll
      for (int _u = -1; _u <= 1; _u++) {
        const int v = yb+_v, u = xb+_u;
        if (!(in(0, v, IH) && in(0, u, IW))) continue;
        auto cor = corres_a[v][u];
        float d = sqr(cor[offset]-x) + sqr(cor[offset+1]-y);
        if (d < best) best = d;
    }}

    // second pass: merge all spatially close correspondences
    #pragma unroll
    for (int _v = -1; _v <= 1; _v++) {
      #pragma unroll
      for (int _u = -1; _u <= 1; _u++) {
        const int v = yb+_v, u = xb+_u;
        if (!(in(0, v, IH) && in(0, u, IW))) continue;
        auto cor = corres_a[v][u];
        float d = sqr(cor[offset]-x) + sqr(cor[offset+1]-y);
        if (d <= tol2*best) { // spatially close
            // merge correspondence if score is better than actual
            if (cor[4] > all_cor[4])
                for (int k = 0; k < OZ; k++) all_cor[k] = cor[k];
        }
    }}
}

void merge_corres_cuda( const torch::Tensor corres, const int offset, const torch::Tensor _inv_rot,
                        const float dmax, torch::Tensor all_corres, const int all_step ) {
    CHECK_CUDA( corres );
    CHECK_CUDA( all_corres );
    CHECK_CUDA( _inv_rot );
    TORCH_CHECK(_inv_rot.is_contiguous(), "inv_rot should be contiguous" );

    const int IH = corres.size(0);
    const int IW = corres.size(1);
    const int IZ = corres.size(2);
    const int OH = all_corres.size(0);
    const int OW = all_corres.size(1);
    const int OZ = all_corres.size(2);
    TORCH_CHECK( IZ == OZ, "corres and all_corres should have the same shape[2]" );

    const int THREADS_PER_LAYER = 512;
    const int N_BLOCKS = (OH * OW + THREADS_PER_LAYER-1) / THREADS_PER_LAYER;

    merge_corres_cuda_kernel<<<N_BLOCKS, THREADS_PER_LAYER>>>(
        OH, OW, OZ, IH, IW, dmax*dmax, offset, _inv_rot.data_ptr<float>(), all_step,
        corres.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
        all_corres.packed_accessor32<float,3,torch::RestrictPtrTraits>());
    CHECK_KERNEL();
}


template <typename scalar_t>
__global__ void mask_correlations_radial_cuda_kernel(
        float radius, const float alpha,
        const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> targets,
        torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> corr ) {

    #define H1 ((int)corr.size(0))
    #define W1 ((int)corr.size(1))
    #define H2 ((int)corr.size(2))
    #define W2 ((int)corr.size(3))

    // each block takes care of one layer corr[j,i,:,:]
    const int j = blockIdx.x / W1;
    const int i = blockIdx.x % W1;
    if (j >= H1) return;

    // read the target center
    const float cx = targets[j][i][0];
    const float cy = targets[j][i][1];
    if (cx != cx || cy != cy) return; // NaN ==> undefined center
    radius *= radius; // squared
    const float alpha_out = (alpha > 1 ? 1 : alpha);
    const float alpha_in  = (alpha < 1 ? 1 : alpha);

    for (int idx = threadIdx.x; idx < H2*W2; idx += blockDim.x) {
        const int v = idx / W2;
        const int u = idx % W2;

        // compute weighting
        float dis2 = sqr(u - cx) + sqr(v - cy);
        float mul = alpha_in;
        if (dis2 > radius)
            mul = 1 - alpha_out*(1 - radius / dis2);

        corr[j][i][v][u] *= mul;
    }
}

void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
                                    const float radius, const float alpha) {
    CHECK_CUDA( corr );
    CHECK_CUDA( targets );

    const int THREADS_PER_LAYER = 512;
    const int N_BLOCKS = H1*W1; // macros from above: corr.size(0) * corr.size(1)

    #undef H1
    #undef W1
    #undef H2
    #undef W2

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(corr.type(), "mask_correlations_radial_cuda", ([&] {
        mask_correlations_radial_cuda_kernel<<<N_BLOCKS, THREADS_PER_LAYER>>>(
            radius, alpha,
            targets.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
            corr.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>());
    }));
    CHECK_KERNEL();
}
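As a side note, the radial attenuation implemented by mask_correlations_radial_cuda_kernel can be re-expressed compactly in plain Python; this sketch only restates the kernel's per-pixel formula for clarity and is not used anywhere in the repository.

# re-expression of the CUDA weighting: inside the squared radius the score is scaled
# by max(alpha,1); outside, it decays from ~1 towards 1-min(alpha,1) as distance grows
def radial_weight(dis2, radius2, alpha):
    alpha_out = min(alpha, 1.0)
    alpha_in  = max(alpha, 1.0)
    if dis2 <= radius2:
        return alpha_in
    return 1.0 - alpha_out * (1.0 - radius2 / dis2)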
core/cuda_deepm/setup.py
ADDED
@@ -0,0 +1,24 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from setuptools import setup
from torch import cuda
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# if you want to compile for all possible CUDA architectures
all_cuda_archs = [] #cuda.get_gencode_flags().replace('compute=','arch=').split()

setup(
    name='cuda_deepm',
    ext_modules = [
        CUDAExtension(
                name = 'cuda_deepm',
                sources = ["func.cpp", "kernels.cu"],
                extra_compile_args = dict(nvcc=['-O2']+all_cuda_archs, cxx=['-O2'])
        )
    ],
    cmdclass = {
        'build_ext': BuildExtension
    })
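Alternatively, and only as a sketch that is not part of this commit, the same extension can be JIT-compiled without setup.py via torch's standard cpp_extension loader, which is convenient for quick experiments:

# assumption: run from core/cuda_deepm/ so the relative source paths resolve
from torch.utils.cpp_extension import load
cuda_deepm = load(name='cuda_deepm', sources=['func.cpp', 'kernels.cu'],
                  extra_cflags=['-O2'], extra_cuda_cflags=['-O2'])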
core/functional.py
ADDED
@@ -0,0 +1,440 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn.functional as F


def affmul( aff, vecs ):
    """ Affine multiplication:
        computes aff @ vecs.T """
    if aff is None: return vecs
    if isinstance(aff, (tuple,list)) or aff.ndim == 3:
        assert len(aff) == 2
        assert 4 <= vecs.shape[-1], bb()
        vecs = vecs.clone() if isinstance(vecs, torch.Tensor) else vecs.copy()
        vecs[...,0:2] = affmul(aff[0], vecs[...,0:2])
        vecs[...,2:4] = affmul(aff[1], vecs[...,2:4])
        return vecs
    else:
        assert vecs.shape[-1] == 2, bb()
        assert aff.shape == (2,3) or (aff.shape == (3,3) and
            aff[2,0] == aff[2,1] == 0 and aff[2,2] == 1), bb()
        return (vecs @ aff[:2,:2].T) + aff[:2,2]


def imresize( img, max_size, mode='area' ):
    # trf: cur_pix --> old_pix
    img, trf = img if isinstance(img,tuple) else (img, torch.eye(3, device=img.device))

    shape = img.shape[-2:]
    if max_size > 0 and max(shape) > max_size:
        new_shape = tuple(i * max_size // max(shape) for i in shape)
        img = F.interpolate( img[None].float(), size=new_shape, mode=mode )[0]
        img.clamp_(min=0, max=255)
        sca = torch.diag(torch.tensor((shape[0]/new_shape[0], shape[1]/new_shape[1], 1), device=img.device))
        img = img.byte()
        trf = trf @ sca # undo sca first

    return img, trf


def rotate_img( img, angle, crop=False ):
    if angle in (0, 90, 180, 270):
        return rotate_img_90(img, angle)

    img, trf = img
    assert trf.shape == (3,3)

    def centered_rotation(rotation, shape, **device):
        # rotation matrix
        # pt_in_original_image = rot * pt_in_rotated_image
        angle = rotation * np.pi / 180
        c, s = np.cos(angle), np.sin(angle)
        rot = torch.tensor([(c, -s, 0), (s, c, 0), (0, 0, 1)], dtype=torch.float32, **device)

        # determine center of rotation before
        H, W = shape
        c_before = torch.tensor((W,H), **device) / 2
        if crop:
            c_after = c_before
            rot_size = (W,H)
        else:
            # enlarge image to fit everything
            corners = torch.tensor([(0, W, W, 0), (0, 0, H, H)], dtype=torch.float32, **device)
            corners = affmul(rot, corners.T).T
            rot_size = (corners.max(dim=1).values - corners.min(dim=1).values + 0.5).int()
            rot_size = (rot_size // 4) * 4 # legacy
            c_after = rot_size / 2

        rot[:2,2] = c_before - affmul(rot, c_after) # fix translation
        return rot, tuple(rot_size)[::-1]

    C, H, W = img.shape
    rot, (OH, OW) = centered_rotation(angle, (H,W), device=img.device)

    # pt_in_original_image = rot * pt_in_rotated_image
    # but pytorch works in [-1,1] coordinates... annoying
    # pt_in_original_1_1 = orig_px_to_1_1 * rot * rotated_1_1_to_px * pt_in_rotated_1_1
    _1_1_to_px = lambda W,H: torch.tensor(((W/2, 0, W/2), (0, H/2, H/2), (0, 0, 1)), device=img.device)
    theta = torch.inverse(_1_1_to_px(W-1,H-1)) @ rot @ _1_1_to_px(OW-1,OH-1)

    grid = F.affine_grid(theta[None,:2], (1, C, OH, OW), align_corners=True)
    res = F.grid_sample(img[None].float(), grid, align_corners=True).to(dtype=img.dtype)[0]
    return res, trf @ rot


def rotate_img_90( img, angle ):
    """ Rotate an image by a multiple of 90 degrees using simple transpose and flip ops.
        img = tuple( image, existing_trf )
        existing_trf: current --> old
    """
    angle = angle % 360
    assert angle in (0, 90, 180, 270), 'cannot handle rotation other than multiple of 90 degrees'
    img, trf = img
    assert trf.shape == (3,3)

    if isinstance(img, np.ndarray):
        assert img.ndim == 3 and 1 <= img.shape[2] <= 3
        new, x, y = np.float32, 1, 0
        flip = lambda i,d: np.flip(i, axis=d)
    elif isinstance(img, torch.Tensor):
        assert img.ndim == 3 and 1 <= img.shape[0] <= 3
        new, x, y = trf.new, -1, -2
        flip = lambda i,d: i.flip(dims=[d])
    H, W = img.shape[y], img.shape[x]

    if angle == 90:
        # point 0,0 --> (0, H-1); W-1,0 --> 0,0
        img = flip(img.swapaxes(x,y), y)
        trf = trf @ new([[0,-1,W-1],[1,0,0],[0,0,1]]) # inverse transform: new --> current
    if angle == 180:
        # point 0,0 --> (W-1, H-1)
        img = flip(flip(img,x), y)
        trf = trf @ new([[-1,0,W-1],[0,-1,H-1],[0,0,1]]) # inverse transform: new --> current
    if angle == 270:
        # point 0,0 --> (H-1, 0); 0,H-1 --> 0,0
        img = flip(img.swapaxes(x,y), x)
        trf = trf @ new([[0,1,0],[-1,0,H-1],[0,0,1]]) # inverse transform: new --> current
    return img, trf


def encode_scale_rot(scale, rot):
    s = np.int32(np.rint(np.log(scale) / (0.5*np.log(2))))
    r = np.int32(np.rint(((-rot) % 360) / 45)) % 8
    return 8*s + (r%8)

def decode_scale_rot( code ):
    s = code // 8
    r = (code % 8)
    return 2 ** (s/2), -((45 * r + 180) % 360 - 180)


def normalized_corr(patches, img, padding='ncc', extra_patch=False, ret_norms=False):
    assert patches.ndim == 4, 'patches shape must be (H*W, C, K, K)'
    P, C, K, K = patches.shape
    assert img.ndim == 3 and img.shape[0] == C, 'img shape must be (C, H, W)'
    eps = torch.finfo(patches.dtype).tiny

    # normalize on patches' side
    norms = patches.view(P,-1).norm(dim=-1)
    patches = patches / norms[:,None,None,None].clamp(min=eps)

    # convolve normalized patches on the unnormalized image
    ninth = 0
    if padding == 'ninth':
        ninth = img[:,-1].mean() # ninth dimension
    img = F.pad(img[None], (K//2,K//2)*2, mode='constant', value=ninth)[0]

    corr = F.conv2d(img[None], patches, padding=0, bias=None)[0]

    # normalize on img's side
    ones = patches.new_ones((1, C, K, K))
    local_norm = torch.sqrt(F.conv2d(img[None]**2, ones))[0]
    corr /= local_norm

    # normalize on patches' side (image borders)
    if padding == 'ncc':
        local_norm = torch.sqrt(F.conv2d(ones, patches**2, padding=2))[0]
        local_norm.clamp_(min=eps)
        for j in range(-2, 3):
            for i in range(-2, 3):
                if i == j == 2: continue # normal case is already normalized
                if i == 2: i = slice(2,-2)
                if j == 2: j = slice(2,-2)
                corr[:,j,i] /= local_norm[:,j,i]

    return (corr, norms) if ret_norms else corr


def true_corr_shape( corr_shape, level ):
    H1, W1, H2, W2 = corr_shape[-4:]
    if level > 0: # recover true size
        H1, W1 = H1-1, W1-1
    return corr_shape[:-4] + (H1, W1, H2, W2)

def children(level, H1, W1, H2, W2):
    """ level: parent level (> 1) """
    gap = 2**(level-2)
    # @ level 1: gap=0.5 (parent at x=1 has children at x=[0.5, 1.5])
    # @ level 2: gap=1   (parent at x=1 has children at x=[0, 2])
    # @ level 3: gap=2   (parent at x=2 has children at x=[0, 4])
    # etc.

    def ravel_child(x, y):
        # x,y is the center of the child patch
        inside = (0 <= x <= W1) and (0 <= y <= H1)
        if gap < 1:
            assert x % 1 == y % 1 == 0.5, bb()
            return int((x-0.5) + (y-0.5) * W1) if inside else -1
        else:
            assert x % 1 == y % 1 == 0, bb()
            return int(x + y * (W1+1)) if inside else -1

    # 4 children for each parent patch (top-left, top-right, bot-left, bot-right, -1 = None)
    parents = []
    for h in range(H1+1):
        for w in range(W1+1):
            # enumerate the 4 children for this patch
            children = [ravel_child(w + gap*tx, h + gap*ty) for ty in (-1,1) for tx in (-1,1)]
            parents.append(children)

    return torch.tensor(parents, dtype=torch.int64)


def sparse_conv(level, corr, weights=None, reverse=False, norm=0.9):
    H1, W1, H2, W2 = true_corr_shape(corr.shape, level-1 + reverse)
    parents = children(level, H1, W1, H2, W2).to(corr.device)
    n_parents = len(parents)

    # perform the sparse convolution 'manually',
    # since sparse convolutions are not implemented in pytorch currently
    corr = corr.view(-1, *corr.shape[-2:])
    if not reverse:
        res = corr.new_zeros((n_parents+1,)+corr.shape[-2:]) # last one = garbage channel
        nrm = corr.new_full((n_parents+1,3,3), 1e-8)
        ones = nrm.new_ones((len(corr),1,1))
        ex = 1
        if weights is not None:
            weights = weights.view(len(corr),1,1)
            corr *= weights # apply weights to correlation maps without increasing memory footprint
            ones *= weights
    else:
        assert corr._base is not None and corr._base.shape[0] == n_parents+1
        corr._base[-1] = 0 # reset garbage layer
        ex = 1 if level > 1 else 0
        n_children = (H1+ex) * (W1+ex)
        res = corr.new_zeros((n_children,)+corr.shape[-2:])

    sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None)
    c = 0
    for y in (-1, 1):
        for x in (-1, 1):
            src_layers = parents[:,c]; c += 1
            # we want to do: res += corr[src_layers] (for all children != -1)
            # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
            tgt_layers = inverse_mapping(src_layers, max_elem=len(corr), default=n_parents)[:-1]

            if not reverse:
                # All of corr's channels MUST be utilized. For level>1, this doesn't hold,
                # so we'll send them to a garbage channel ==> res[n_parents]
                sel = good_slice( tgt_layers < n_parents )

                res[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], corr[sel,sl(y),sl(x)])
                nrm[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))
            else:
                ''' parent=199=11*17+12 @ (x=48, y=44) at level=1
                      |-- child=171 @ (x=46,y=42) at level0
                      |-- child=172 @ (x=50,y=42) at level0
                      |-- child=187 @ (x=46,y=46) at level0
                      |-- child=188 @ (x=50,y=46) at level0
                '''
                out = res[:,sl(y),sl(x)]
                sel = tgt_layers[:n_children]
                torch.maximum(out, corr._base[sel,sl(-y),sl(-x)], out=out)

    if not reverse:
        if weights is not None: corr /= weights.clamp(min=1e-12) # cancel weights
        weights = norm_borders(res, nrm, norm=norm)[:-1]
        res = res[:-1] # remove garbage channel
    res = res.view(H1+ex, W1+ex, *res.shape[-2:])
    return res if reverse else (res, weights)

def norm_borders( res, nrm, norm=0.9 ):
    """ Apply some border normalization, modulated by `norm`:
        - if norm=0: no normalization at all
        - if norm=1: full normalization
        Formula: nrm = k * (nrm/k)**p = k**(1-p) * nrm**p,
                 with k=nrm[:,1,1] and p=norm
    """
    new_weights = nrm[...,1,1].clone()
    nrm = (nrm[...,1:2,1:2] ** (1-norm)) * (nrm ** norm)
    # assert not torch.isnan(nrm).any()

    # normalize results on the borders
    res[...,0   ,0   ] /= nrm[...,0  ,0  ]
    res[...,0   ,1:-1] /= nrm[...,0  ,1:2]
    res[...,0   , -1 ] /= nrm[...,0  ,2  ]
    res[...,1:-1,0   ] /= nrm[...,1:2,0  ]
    res[...,1:-1,1:-1] /= nrm[...,1:2,1:2]
    res[...,1:-1, -1 ] /= nrm[...,1:2,2  ]
    res[..., -1 ,0   ] /= nrm[...,2  ,0  ]
    res[..., -1 ,1:-1] /= nrm[...,2  ,1:2]
    res[..., -1 , -1 ] /= nrm[...,2  ,2  ]
    return new_weights


def inverse_mapping( map, max_elem=None, default=None):
    """ Given a mapping {i: j} (as a torch array), output the inverse mapping {j: i}. """
    assert isinstance(map, torch.Tensor) and map.ndim == 1
    if max_elem is None: max_elem = map.max()
    if default is None:
        index = torch.empty(max_elem+1, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
    else:
        index = torch.full((max_elem+1,), default, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
    index[map] = torch.arange(len(map), device=map.device)
    return index


def good_slice( nonzero ):
    good = nonzero.nonzero().ravel()
    return slice(good.min().item(), good.max().item()+1)


def max_unpool(upper, lower, exclude_border=True):
    # re-compute max-pool indices
    if exclude_border:
        # apparently, we cannot unpool on the bottom and right borders in legacy code (local_argmax with ex=1)
        _, pos = F.max_pool2d(lower[:,:,:-1,:-1], 3, padding=1, stride=2, return_indices=True, ceil_mode=True)
        W1 = lower.shape[-1]
        pos = (pos//(W1-1))*W1 + (pos%(W1-1)) # fix the shortening
    else:
        _, pos = F.max_pool2d(lower, 3, padding=1, stride=2, return_indices=True)

    # because there are potential collisions between overlapping 3x3 cells,
    # which pytorch does not handle, we unpool in 4 successive non-overlapping steps
    for i in range(2):
        for j in range(2):
            # stride=4 instead of 1 because pytorch does some size checking; this is a hack
            tmp = F.max_unpool2d(upper[:,:,i::2,j::2], pos[:,:,i::2,j::2], kernel_size=3, padding=0, stride=4, output_size=lower.shape[-2:])
            if i == j == 0:
                res = tmp
            else:
                torch.maximum(res, tmp, out=res)

    # add scores to existing lower correlation map
    lower += res
    return lower


def mgrid( shape, **kw ):
    """ Returns coordinates in (x, y) order (contrary to numpy, which uses (y, x)). """
    if isinstance(shape, torch.Tensor): shape = shape.shape
    res = torch.meshgrid(*[torch.arange(n, dtype=torch.float32, **kw) for n in shape], indexing='ij')
    return torch.stack(res[::-1], dim=-1).view(-1,2)


def check_corres( corres, step, rot=None ):
    H, W, two = corres.shape
    assert two == 2
    if isinstance(corres, np.ndarray):
        corres = torch.from_numpy(corres)
    if rot is not None:
        corres = affmul(rot, corres)
    gt = mgrid(corres.shape[:2]).view(H,W,2)
    assert ((gt - corres // step).abs() <= 2).float().mean() > 0.99, bb()


def best_correspondences(corr):
    """ All positions are returned as x1, y1, x2, y2 """
    if isinstance(corr, tuple): return corr # for legacy
    H1, W1, H2, W2 = corr.shape
    fix1 = lambda arr: 4*arr+2 # center of cells in img1
    div = lambda a,b: torch.div(a, b, rounding_mode='trunc') # because of warning in pytorch 1.9+

    # best scores in img1
    score1, pos1 = corr.view(H1, W1, H2*W2).max(dim=-1)
    pos1 = torch.cat((fix1(mgrid(score1, device=pos1.device)), pos1.view(-1,1)%W2, div(pos1.view(-1,1),W2)), dim=-1)

    # best scores in img2
    score2, pos2 = max_pool3d( corr, kernel_size=4, stride=4 )
    pos2, score2 = pos2.view(-1,1), score2.squeeze()
    pos2 = torch.cat((fix1(div(pos2,W2*H2)%W1), fix1(div(pos2,(W1*H2*W2))), pos2%W2, div(pos2,W2)%H2), dim=-1).float()

    return (pos1, score1), (pos2, score2)


def intersection( set1_, set2_ ):
    """ Returns the indices of values in set1 that are duplicated in set2 """
    set1, map1 = set1_.squeeze().unique(return_inverse=True) # map1: i1 -> j1
    set2 = set2_.squeeze().unique()
    combined = torch.cat((set1, set2))

    uniques, inverse, counts = combined.unique(return_counts=True, return_inverse=True)
    # j -> u, i -> j, j -> n
    # we are interested only in (j -> i) for n > 1:
    # assert counts.max() <= 2, 'there were non-unique values in either set1 or set2'
    # intersected_values = uniques[counts > 1]
    inverse1 = inverse_mapping(inverse[:len(set1)], max_elem=len(uniques)-1)
    intersected_indices1 = inverse1[counts>1]
    return inverse_mapping(map1, max_elem=len(set1)-1)[intersected_indices1]


def reciprocal(self, corres1, corres2 ):
    pos1, score1 = corres1
    pos2, score2 = corres2
    (H1, W1), (H2, W2) = score1.shape, map(lambda i: 4*i+1, score2.shape)

    to_int = pos1.new_tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
    inter1 = intersection(pos1@to_int, pos2@to_int)
    res = torch.cat((pos1[inter1], score1.view(-1,1)[inter1], 0*score1.view(-1,1)[inter1]), dim=-1)
    return res


def max_pool3d( corr, kernel_size=4, stride=4 ):
    H1, W1, H2, W2 = corr.shape
    ks, st = kernel_size, stride
    if corr.numel() >= 2**31 and corr.device != torch.device('cpu'):
        # re-implementation due to a bug in pytorch
        import core.cuda_deepm as kernels
        return kernels.max_pool3d( corr.view(1, H1*W1, H2, W2), kernel_size, stride)
    else:
        return F.max_pool3d( corr.view(1, 1, H1*W1, H2, W2), kernel_size=(H1*W1,ks,ks), stride=(1,st,st), return_indices=True)


def forward_cuda(self, level, lower, weights=None, pooled=False):
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    assert lower.numel() < 2**31, 'please use cuda-lowmem, pytorch cannot handle big tensors'
    pooled = lower if pooled else F.max_pool2d(lower, 3, padding=1, stride=2)
    return kernels.forward_agg(level, self.border_inv, pooled, weights)

def forward_cuda_lowmem(self, level, lower, weights=None):
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    return kernels.forward_pool_agg(level, self.border_inv, lower, weights)

def backward_cuda(self, level, pyramid):
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    kernels.backward_agg_unpool(level, pyramid[level], pyramid[level-1], True)
    # assert not torch.isnan(pyramid[level-1]).any(), bb()
    return pyramid[level-1]

def merge_corres(self, corres, rots, all_corres, code):
    " rot : reference --> rotated "
    all_step = self.matcher.pixel_desc.get_atomic_patch_size() // 2 # step size in all_corres
    dev = all_corres[0][1].device

    # stack correspondences: one (x1,y1,x2,y2,score,code) row per bin
    corres = [torch.cat((p.view(*s.shape,4), s[:,:,None], torch.full_like(s[:,:,None], code)), dim=2) for (p,s) in corres]

    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    kernels.merge_corres_one_side( corres[0].to(dev), 0, rots[0].to(dev), all_corres[0][1], all_step )
    kernels.merge_corres_one_side( corres[1].to(dev), 2, rots[1].to(dev), all_corres[1][1], all_step )
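A quick round-trip check of the scale/rotation encoding above (an illustrative snippet, not part of the file): encode_scale_rot quantizes scale to half-octaves and rotation to 45-degree steps, and decode_scale_rot inverts it exactly on that quantization grid.

# round-trip on the quantization grid: scale=2 -> s=2 half-octaves, rot=90 -> r=6
code = encode_scale_rot(2.0, 90)    # = 8*2 + 6 = 22
scale, rot = decode_scale_rot(code) # -> (2.0, 90)
assert abs(scale - 2.0) < 1e-6 and rot % 360 == 90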
core/losses/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from .multiloss import MultiLoss
from .pixel_ap_loss import PixelAPLoss
from .ap_loss_sampler import NghSampler
from .unsupervised_deepmatching_loss import DeepMatchingLoss
core/losses/ap_loss.py
ADDED
@@ -0,0 +1,61 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import torch
import torch.nn as nn


class APLoss (nn.Module):
    """ Differentiable AP loss, through quantization.

        Input: (N, M) values in [min, max]
        label: (N, M) values in {0, 1}

        Returns: list of query AP (for each n in {1..N})
        Note: typically, you want to minimize 1 - mean(AP)
    """
    def __init__(self, nq=25, min=0, max=1, euc=False):
        nn.Module.__init__(self)
        assert isinstance(nq, int) and 2 <= nq <= 100
        self.nq = nq
        self.min = min
        self.max = max
        self.euc = euc
        gap = max - min
        assert gap > 0

        # init quantizer = non-learnable (fixed) convolution
        self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True).requires_grad_(False)
        a = (nq-1) / gap
        # 1st half = lines passing through (min+x, 1) and (min+x+1/a, 0) with x = {nq-1..0}*gap/(nq-1)
        q.weight.data[:nq] = -a
        q.bias.data[:nq] = a*min + torch.arange(nq, 0, -1) # b = 1 + a*(min+x)
        # 2nd half = lines passing through (min+x, 1) and (min+x-1/a, 0) with x = {nq-1..0}*gap/(nq-1)
        q.weight.data[nq:] = a
        q.bias.data[nq:] = torch.arange(2-nq, 2, 1) - a*min # b = 1 - a*(min+x)
        # the first and last bins are special: just a horizontal straight line
        q.weight.data[0] = q.weight.data[-1] = 0
        q.bias.data[0] = q.bias.data[-1] = 1

    def compute_AP(self, x, label):
        N, M = x.shape
        if self.euc: # euclidean distance in the same range as similarities
            x = 1 - torch.sqrt(2.001 - 2*x)

        # quantize all predictions
        q = self.quantizer(x.unsqueeze(1))
        q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M

        nbs = q.sum(dim=-1) # number of samples per bin: N x Q = c
        rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples per bin = c+ : N x Q
        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision
        rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1]

        ap = (prec * rec).sum(dim=-1) # per-image AP
        return ap

    def forward(self, x, label):
        assert x.shape == label.shape # N x M
        return self.compute_AP(x, label)
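To illustrate the interface, here is a toy sketch following the docstring's conventions (this is not a test from the repository, and it assumes APLoss from the file above is in scope): one query row of similarity scores in [0,1] with binary labels yields one differentiable AP value per row.

# toy example: 1 query, 4 retrieved items, 2 of them correct
import torch
aploss = APLoss(nq=25, min=0, max=1)
scores = torch.tensor([[0.9, 0.8, 0.3, 0.1]])
labels = torch.tensor([[1,   0,   1,   0  ]])
ap = aploss(scores, labels) # shape (1,): AP of the query
loss = 1 - ap.mean()        # minimized as suggested in the docstring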
core/losses/ap_loss_sampler.py
ADDED
@@ -0,0 +1,131 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class NghSampler (nn.Module):
    """ Given dense feature maps and pixel-dense flow,
    compute a subset of all correspondences and return their scores and labels.

    Distance to GT => 0 ... pos_d ... neg_d ... ngh
    Pixel label    => + + + + + +   0 0   - - - - - - -

    Subsample on query side: if > 0, regular grid
                             if < 0, random points
    In both cases, the number of query points is = W*H/subq**2
    """
    def __init__(self, ngh, subq=-8, subd=1, pos_d=2, neg_d=4, border=16, subd_neg=-8):
        nn.Module.__init__(self)
        assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
        self.ngh = ngh
        self.pos_d = pos_d
        self.neg_d = neg_d
        assert subd <= ngh or ngh == 0
        assert subq != 0
        self.sub_q = subq
        self.sub_d = subd
        self.sub_d_neg = subd_neg
        if border is None: border = ngh
        assert border >= ngh, 'border has to be larger than ngh'
        self.border = border
        self.precompute_offsets()

    def precompute_offsets(self):
        pos_d2 = self.pos_d**2
        neg_d2 = self.neg_d**2
        rad2 = self.ngh**2
        rad = (self.ngh//self.sub_d) * self.ngh  # make an integer multiple
        pos = []
        neg = []
        for j in range(-rad, rad+1, self.sub_d):
            for i in range(-rad, rad+1, self.sub_d):
                d2 = i*i + j*j
                if d2 <= pos_d2:
                    pos.append((i,j))
                elif neg_d2 <= d2 <= rad2:
                    neg.append((i,j))

        self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
        self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())

    def gen_grid(self, step, aflow):
        B, two, H, W = aflow.shape
        dev = aflow.device
        b1 = torch.arange(B, device=dev)
        if step > 0:
            # regular grid
            x1 = torch.arange(self.border, W-self.border, step, device=dev)
            y1 = torch.arange(self.border, H-self.border, step, device=dev)
            H1, W1 = len(y1), len(x1)
            shape = (B, H1, W1)
            x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1)
            y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1)
            b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1)
        else:
            # randomly spread
            n = (H - 2*self.border) * (W - 2*self.border) // step**2
            x1 = torch.randint(self.border, W-self.border, (n,), device=dev)
            y1 = torch.randint(self.border, H-self.border, (n,), device=dev)
            x1 = x1[None,:].expand(B,n).reshape(-1)
            y1 = y1[None,:].expand(B,n).reshape(-1)
            b1 = b1[:,None].expand(B,n).reshape(-1)
            shape = (B, n)
        return b1, y1, x1, shape

    def forward(self, feats, confs, aflow, **kw):
        B, two, H, W = aflow.shape
        assert two == 2, bb()
        feat1, conf1 = feats[0], (confs[0] if confs else None)
        feat2, conf2 = feats[1], (confs[1] if confs else None)

        # positions in the first image
        b_, y1, x1, shape = self.gen_grid(self.sub_q, aflow)

        # sample features from the first image
        feat1 = feat1[b_, :, y1, x1]
        qconf = conf1[b_, :, y1, x1].view(shape) if confs else None

        # sample GT from the second image
        xy2 = (aflow[b_, :, y1, x1] + 0.5).long().t()
        mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H)
        mask = mask.view(shape)

        def clamp(xy):
            torch.clamp(xy[0], 0, W-1, out=xy[0])
            torch.clamp(xy[1], 0, H-1, out=xy[1])
            return xy

        # compute positive scores
        xy2p = clamp(xy2[:,None,:] + self.pos_offsets[:,:,None])
        pscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2p[1], xy2p[0]])

        # compute negative scores
        xy2n = clamp(xy2[:,None,:] + self.neg_offsets[:,:,None])
        nscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2n[1], xy2n[0]])

        if self.sub_d_neg:
            # add distractors from a grid
            b3, y3, x3 = self.gen_grid(self.sub_d_neg, aflow)[:3]
            distractors = feat2[b3, :, y3, x3]
            dscores = torch.einsum('nk,ik->ni', feat1, distractors)
            del distractors

            # remove scores that correspond to positives or nulls
            x2, y2 = xy2 = xy2.float()
            xy3 = torch.stack((x3,y3)).float()
            dis2 = torch.cdist((xy2+b_*512).T, (xy3+b3*512).T, compute_mode='donot_use_mm_for_euclid_dist')
            dscores[dis2 < self.neg_d] = 0

        scores = torch.cat((pscores, nscores, dscores), dim=1)

        gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1

        return scores, gt, mask, qconf
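A usage sketch (an assumption for illustration, not repository code): feed the sampler L2-normalized dense features and an identity ground-truth flow, with shapes following the docstring above.

    import torch
    import torch.nn.functional as F

    sampler = NghSampler(ngh=16, subq=-8, subd=1, pos_d=2, neg_d=4, border=16)
    B, C, H, W = 2, 128, 64, 64
    feat1 = F.normalize(torch.randn(B, C, H, W), dim=1)
    feat2 = F.normalize(torch.randn(B, C, H, W), dim=1)
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    aflow = torch.stack((xs, ys)).float()[None].expand(B, 2, H, W)  # identity flow
    scores, gt, mask, qconf = sampler((feat1, feat2), None, aflow)  # qconf is None here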
core/losses/multiloss.py
ADDED
@@ -0,0 +1,57 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tools.trainer import backward


class MultiLoss (nn.Module):
    """ This function handles both supervised and unsupervised samples.
    """
    def __init__(self, loss_sup, loss_unsup, alpha=0.3, inner_bw=True):
        super().__init__()
        assert 0 <= alpha
        self.alpha_sup = 1  # coef of the self-supervised loss
        self.loss_sup = loss_sup

        self.alpha_unsup = alpha  # coef of the unsupervised loss
        self.loss_unsup = loss_unsup

        self.inner_bw = inner_bw

    def forward(self, desc1, desc2, homography, **kw):
        sl_sup, sl_unsup = split_batch_sup_unsup(homography, 512 if self.inner_bw else 8)

        inner_bw = self.inner_bw and self.training and torch.is_grad_enabled()
        if inner_bw: (desc1, desc1_), (desc2, desc2_) = pause_gradient((desc1, desc2))
        kw['desc1'], kw['desc2'], kw['homography'] = desc1, desc2, homography

        (sup_name, sup_loss) ,= self.loss_sup(backward_loss=inner_bw*self.alpha_sup, **{k: v[sl_sup] for k,v in kw.items()}).items()
        if inner_bw and sup_loss: sup_loss = backward(sup_loss)  # backward to desc1 and desc2

        (uns_name, uns_loss) ,= self.loss_unsup(**{k: v[sl_unsup] for k,v in kw.items()}).items()
        uns_loss = self.alpha_unsup * uns_loss
        if inner_bw and uns_loss: uns_loss = backward(uns_loss)  # backward to desc1 and desc2

        loss = sup_loss + uns_loss
        return {'loss': (loss, [(desc1_, desc1.grad), (desc2_, desc2.grad)]), sup_name: float(sup_loss), uns_name: float(uns_loss)}


def pause_gradient(objs):
    return [(obj.detach().requires_grad_(True), obj) for obj in objs]


def split_batch_sup_unsup(homography, max_sup=512):
    # split the batch into supervised / unsupervised samples
    i = int(torch.isfinite(homography[:,0,0]).sum())  # index of the first unsupervised sample
    sl_sup, sl_unsup = slice(0, min(i, max_sup)), slice(i, None)

    assert torch.isfinite(homography[sl_sup]).all(), 'batch is not properly sorted!'
    assert torch.isnan(homography[sl_unsup]).all(), 'batch is not properly sorted!'
    return sl_sup, sl_unsup
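The batch contract, inferred from the assertions above and sketched here for clarity: supervised pairs come first, and unsupervised pairs are flagged by an all-NaN homography.

    import torch

    H = torch.stack([torch.eye(3), torch.eye(3), torch.full((3, 3), float('nan'))])
    sl_sup, sl_unsup = split_batch_sup_unsup(H)
    assert H[sl_sup].shape[0] == 2 and H[sl_unsup].shape[0] == 1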
core/losses/pixel_ap_loss.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import torch
import torch.nn as nn
import torch.nn.functional as F

from .ap_loss import APLoss
from datasets.utils import applyh


class PixelAPLoss (nn.Module):
    """ Computes the pixel-wise AP loss:
    Given two images and the ground-truth optical flow, computes the AP per pixel.

    feat1: (B, C, H, W)  pixel-wise features extracted from img1
    feat2: (B, C, H, W)  pixel-wise features extracted from img2
    aflow: (B, 2, H, W)  absolute flow: aflow[..., y1, x1] = (x2, y2)
    """
    def __init__(self, sampler, nq=20, inner_bw=False, bw_step=256):
        nn.Module.__init__(self)
        self.aploss = APLoss(nq, min=0, max=1, euc=False)
        self.name = 'pixAP'
        self.sampler = sampler
        self.inner_bw = inner_bw
        self.bw_step = bw_step

    def loss_from_ap(self, ap, rel):
        return 1 - ap

    def forward(self, desc1, desc2, homography, backward_loss=None, **kw):
        if len(desc1) == 0: return dict(ap_loss=0)
        aflow = aflow_from_H(homography, desc1)
        descriptors = (desc1, desc2)
        scores, gt, msk, qconf = self.sampler(descriptors, kw.get('reliability'), aflow)

        # compute pixel-wise AP
        n = msk.numel()
        if n == 0: return 0
        scores, gt = scores.view(n,-1), gt.view(n,-1)

        backward_loss = backward_loss or self.inner_bw
        if self.training and torch.is_grad_enabled() and backward_loss:
            # progressive loss computation and backward: low memory, but slow
            scores_, qconf_ = scores, qconf if qconf is not None else scores.new_ones(msk.shape)
            scores = scores.detach().requires_grad_(True)
            qconf = qconf_.detach().requires_grad_(True)
            msk = msk.ravel()

            loss = 0
            for i in range(0, n, self.bw_step):
                sl = slice(i, i+self.bw_step)
                ap = self.aploss(scores[sl], gt[sl])
                pixel_loss = self.loss_from_ap(ap, qconf.ravel()[sl] if qconf is not None else None)
                l = backward_loss / msk.sum() * pixel_loss[msk[sl]].sum()
                loss += float(l)
                l.backward()  # accumulate gradients
            loss = (loss, [(scores_, scores.grad)])
            if qconf_.requires_grad: loss[1].append((qconf_, qconf.grad))

        else:
            ap = self.aploss(scores, gt).view(msk.shape)
            pixel_loss = self.loss_from_ap(ap, qconf)
            loss = pixel_loss[msk].mean()

        return dict(ap_loss=loss)


def make_grid(B, H, W, device):
    b = torch.arange(B, device=device).view(B,1,1).expand(B,H,W)
    y = torch.arange(H, device=device).view(1,H,1).expand(B,H,W)
    x = torch.arange(W, device=device).view(1,1,W).expand(B,H,W)
    return b.view(B,H*W), torch.stack((x,y), dim=-1).view(B,H*W,2)


def aflow_from_H(H_1to2, feat1):
    B, _, H, W = feat1.shape
    b, pos1 = make_grid(B, H, W, feat1.device)
    pos2 = applyh(H_1to2, pos1.float())
    return pos2.view(B,H,W,2).permute(0,3,1,2)
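As a sanity check (illustrative only), the identity homography should map each grid position onto itself; applyh is the batched homography application imported above from datasets.utils, assumed here to accept torch tensors as it does in this code path.

    import torch

    feat1 = torch.zeros(1, 8, 4, 4)      # only the shape is used
    H_1to2 = torch.eye(3)[None]          # identity homography, batch of 1
    aflow = aflow_from_H(H_1to2, feat1)  # (1, 2, 4, 4)
    # here aflow[0, :, y, x] == (x, y) for every pixel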
core/losses/unsupervised_deepmatching_loss.py
ADDED
@@ -0,0 +1,146 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb

import torch
import torch.nn as nn
import torch.nn.functional as F

from core import functional as myF


class DeepMatchingLoss (nn.Module):
    """ This loss is based on DeepMatching (IJCV'16).
    atleast:   (int) minimum image size at which the pyramid construction stops.
    sub:       (int) prior subsampling
    way:       (str) which way to compute the asymmetric matching ('1', '2' or '12')
    border:    (int) ignore pixels too close to the border
    rectify_p: (float) non-linear power-rectification in DeepMatching
    eps:       (float) epsilon for the L1 normalization. Kinda handles unmatched pixels.
    """
    def __init__(self, eps=0.03, atleast=5, sub=2, way='12', border=16, rectify_p=1.5):
        super().__init__()
        assert way in ('1','2','12')
        self.subsample = sub
        self.border = border
        self.way = way
        self.atleast = atleast
        self.rectify_p = rectify_p
        self.eps = eps

        self._cache = {}

    def rectify(self, corr):
        corr = corr.clip_(min=0)
        corr = corr ** self.rectify_p
        return corr

    def forward(self, desc1, desc2, **kw):
        # 1 --> 2
        loss1 = self.forward_oneway(desc1, desc2, **kw) \
                if '1' in self.way else 0

        # 2 --> 1
        loss2 = self.forward_oneway(desc2, desc1, **kw) \
                if '2' in self.way else 0

        return dict(deepm_loss=(loss1+loss2)/len(self.way))

    def forward_oneway(self, desc1, desc2, dbg=(), **kw):
        assert desc1.shape[:2] == desc2.shape[:2]

        # prior subsampling
        s = slice(self.border, -self.border or None, self.subsample)
        desc1, desc2 = desc1[...,s,s], desc2[...,s,s]
        desc1 = desc1[:,:,2::4,2::4]  # subsample patches in the 1st image
        B, D, H1, W1, H2, W2 = desc1.shape + desc2.shape[-2:]
        if B == 0: return 0  # empty batch

        # initial 4D correlation volume
        corr = torch.bmm(desc1.reshape(B,D,-1).transpose(1,2), desc2.reshape(B,D,-1)).view(B,H1,W1,H2,W2)

        # build the pyramid
        pyramid = self.deep_matching(corr)
        corr = pyramid[-1]  # high-level correlation
        corr = self.rectify(corr)

        # L1 norm
        B, H1, W1, H2, W2 = corr.shape
        corr = corr / (corr.reshape(B,H1*W1,-1).sum(dim=-1).view(B,H1,W1,1,1) + self.eps)

        # squared L2 norm
        loss = - torch.square(corr).sum() / (B*H1*W1)
        return loss

    def deep_matching(self, corr):
        # print(f'level=0 {corr.shape=}')
        weights = None
        pyramid = [corr]
        for level in range(1, 999):
            corr, weights = self.forward_level(level, corr, weights)
            pyramid.append(corr)
            # print(f'{level=} {corr.shape=}')
            if weights.sum() == 0: break  # img1 has become too small
            if min(corr.shape[-2:]) < 2*self.atleast: break  # img2 has become too small
        return pyramid

    def forward_level(self, level, corr, weights):
        B, H1, W1, H2, W2 = corr.shape

        # max-pooling
        pooled = F.max_pool2d(corr.view(B,H1*W1,H2,W2), 3, padding=1, stride=2)
        pooled = pooled.view(B, H1, W1, *pooled.shape[-2:])

        # print(f'rectifying corr at {level=}')
        pooled = self.rectify(pooled)

        # sparse conv
        key = level, H1, W1, H2, W2
        if key not in self._cache:
            B, H1, W1, H2, W2 = myF.true_corr_shape(pooled.shape, level-1)
            self._cache[key] = myF.children(level, H1, W1, H2, W2).to(corr.device)

        return sparse_conv(level, pooled, self._cache[key], weights)


def sparse_conv(level, corr, parents, weights=None, border_norm=0.9):
    B, H1, W1, H2, W2 = myF.true_corr_shape(corr.shape, level-1)
    n_cache = len(parents)

    # perform the sparse convolution 'manually',
    # since sparse convolutions are not implemented in pytorch currently
    corr = corr.view(B, -1, H2, W2)

    res = corr.new_zeros((B, n_cache+1, H2, W2))  # last one = garbage channel
    nrm = corr.new_full((n_cache+1, 3, 3), torch.finfo(corr.dtype).eps)
    ones = nrm.new_ones((corr.shape[1], 1, 1))
    ex = 1
    if weights is not None:
        weights = weights.view(corr.shape[1],1,1)
        corr = corr * weights[None]  # apply weights to the correlation maps beforehand
        ones *= weights

    sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None)
    c = 0
    for y in (-1, 1):
        for x in (-1, 1):
            src_layers = parents[:,c]; c += 1
            # we want to do: res += corr[src_layers] (for all children != -1)
            # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
            tgt_layers = myF.inverse_mapping(src_layers, max_elem=corr.shape[1], default=n_cache)[:-1]

            # All of corr's channels MUST be utilized. For level > 1 this doesn't hold,
            # so we send the unused ones to a garbage channel ==> res[n_cache]
            sel = myF.good_slice(tgt_layers < n_cache)

            res[:,:,sl(-y),sl(-x)].index_add_(1, tgt_layers[sel], corr[:,sel,sl(y),sl(x)])
            nrm[  :,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))

    # normalize the borders
    weights = myF.norm_borders(res, nrm, norm=border_norm)[:-1]

    res = res[:,:-1]  # remove the garbage channel
    return res.view(B, H1+ex, W1+ex, *res.shape[-2:]), weights
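The core of forward_oneway is the initial 4D correlation volume; the standalone sketch below (illustrative shapes, not repository code) reproduces just that step: every subsampled patch descriptor of img1 is correlated with every position of img2.

    import torch

    B, D, H1, W1, H2, W2 = 1, 16, 8, 8, 32, 32
    desc1 = torch.randn(B, D, H1, W1)
    desc2 = torch.randn(B, D, H2, W2)
    corr = torch.bmm(desc1.reshape(B, D, -1).transpose(1, 2),
                     desc2.reshape(B, D, -1)).view(B, H1, W1, H2, W2)
    # corr[b, i, j] is the (H2, W2) correlation map of patch (i, j) of img1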
core/pixel_desc.py
ADDED
@@ -0,0 +1,60 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvf

from core.conv_mixer import ConvMixer

norm_RGB = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


class PixelDesc (nn.Module):
    def __init__(self, path='models/PUMP_st.pt'):
        super().__init__()
        state_dict = torch.load(path, 'cpu')
        self.pixel_desc = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9).eval()
        self.pixel_desc.load_state_dict(state_dict)

    def configure(self, pipeline):
        # hot-update of the default HOG-based pipeline
        pipeline.__class__ = type(type(pipeline).__name__+'_Trained', (DescPipeline, type(pipeline)), {})
        return self

    def get_atomic_patch_size(self):
        return 4

    def forward(self, img, stride=1, offset=0):
        if img.ndim == 3: img = img[None]
        trf = torch.eye(3, device=img.device)

        desc = self.pixel_desc(img)
        desc = desc[..., offset::stride, offset::stride].contiguous()  # frees memory
        return desc, trf


class DescPipeline:
    def extract_descs(self, img1, img2, dtype=None):
        # this will rotate the image if needed
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)

        # convert to float and normalize the std
        fimg1, fimg2 = [norm_RGB(img.type(dtype)/255) for img in (img1, img2)]

        self.pixel_desc.type(fimg1.dtype)
        desc1, trf1 = self.pixel_desc(fimg1, stride=4, offset=2)
        desc2, trf2 = self.pixel_desc(fimg2)
        return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1@trf1, sca2@trf2)

    def first_level(self, desc1, desc2, **kw):
        B, C, H, W = desc1.shape
        weights = desc1.permute(0, 2, 3, 1).view(H*W, C, 1, 1)  # rearrange(desc1, '1 C H W -> (H W) C 1 1')
        corr = F.conv2d(desc2, weights, padding=0, bias=None)[0]
        norms = torch.ones(desc1.shape[-2:], device=corr.device)
        return corr.view(desc1.shape[-2:]+desc2.shape[-2:]), norms
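An illustrative sketch of descriptor extraction (the checkpoint path and the manual normalization are assumptions for this snippet; in the full pipeline, DescPipeline.extract_descs handles normalization and dtype):

    import torch

    net = PixelDesc(path='checkpoints/PUMP-stytrf.pt')  # any compatible checkpoint
    img = norm_RGB(torch.rand(3, 256, 256))             # toy RGB image already in [0,1]
    with torch.no_grad():
        desc, trf = net(img, stride=4, offset=2)        # dense descriptors, subsampled 4x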
datasets/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from .image_set import *
from .web_images import RandomWebImages
from .pair_dataset import *
from .pair_loader import *
from .sfm120k import *
datasets/demo_warp/mountains_src.jpg
ADDED
datasets/demo_warp/mountains_tgt.jpg
ADDED
datasets/image_set.py
ADDED
@@ -0,0 +1,91 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os
from os.path import *
from PIL import Image


class ImageSet(object):
    """ Base class for an image dataset.
    """
    def __init__(self, root, imgs):
        self.root = root
        self.imgs = imgs
        assert imgs, f'Empty image set in {root}'

    def init_from_folder(self, *args, **kw):
        imset = ImageSet.from_folder(*args, **kw)
        ImageSet.__init__(self, imset.root, imset.imgs)

    def __len__(self):
        return len(self.imgs)

    def get_image_path(self, idx):
        return os.path.join(self.root, self.imgs[idx])

    def get_image(self, idx):
        fname = self.get_image_path(idx)
        try:
            return Image.open(fname).convert('RGB')
        except Exception as e:
            raise IOError("Could not load image %s (reason: %s)" % (fname, str(e)))

    __getitem__ = get_image

    @staticmethod
    def from_folder(root, exts=('.jpg','.jpeg','.png','.ppm'), recursive=False, listing=False, check_imgs=False):
        """
        recursive: bool or func. If a function, it must evaluate to True on the directory name.
        """
        if listing:
            if listing is True: listing = f"list_imgs{'_recursive' if recursive else ''}.txt"
            flist = join(root, listing)
            try: return ImageSet.from_listing(root, flist)
            except IOError: print(f'>> ImageSet.from_folder(listing=True): entering {root}...')

        if check_imgs is True:  # default verification function
            check_imgs = verify_img

        for _, dirnames, dirfiles in os.walk(root):
            imgs = sorted([f for f in dirfiles if f.lower().endswith(exts)])
            if check_imgs: imgs = [img for img in imgs if check_imgs(join(root,img))]

            if recursive:
                for dirname in sorted(dirnames):
                    if callable(recursive) and not recursive(join(root,dirname)): continue
                    imset = ImageSet.from_folder(join(root,dirname), exts=exts, recursive=recursive, listing=listing, check_imgs=check_imgs)
                    imgs += [join(dirname,f) for f in imset.imgs]
            break  # recursion is handled internally, only walk the top level

        if listing:
            try: open(flist,'w').write('\n'.join(imgs))
            except IOError: pass  # write permission denied
        return ImageSet(root, imgs)

    @staticmethod
    def from_listing(root, list_path):
        return ImageSet(root, open(list_path).read().splitlines())

    def circular_pad(self, min_size):
        assert self.imgs, 'cannot pad an empty image set'
        while len(self.imgs) < min_size:
            self.imgs += self.imgs  # artificially augment the size
        self.imgs = self.imgs[:min_size or None]
        return self

    def __repr__(self):
        prefix = os.path.commonprefix((self.get_image_path(0), self.get_image_path(len(self)-1)))
        return f'{self.__class__.__name__}({len(self)} images from {prefix}...)'


def verify_img(path, exts=None):
    if exts and not path.lower().endswith(exts): return False
    try:
        Image.open(path).convert('RGB')  # try to open it
        return True
    except:
        return False
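Typical usage (a sketch): index a folder of images and read one back as RGB. The demo folder shipped with the repository is used here for concreteness.

    imset = ImageSet.from_folder('datasets/demo_warp')
    print(len(imset), imset.get_image_path(0))
    img = imset[0]   # PIL.Image in RGB mode, via __getitem__ = get_image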
datasets/pair_dataset.py
ADDED
@@ -0,0 +1,226 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os, os.path as osp
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch

from .image_set import ImageSet
from .transforms import instanciate_transforms
from .utils import DatasetWithRng
invh = np.linalg.inv


class ImagePairs (DatasetWithRng):
    """ Base class for a dataset that serves image pairs.
    """
    imgs = None  # regular image dataset
    pairs = []   # list of (idx1, idx2), ...

    def __init__(self, image_set, pairs, trf=None, **rng):
        assert image_set and pairs, 'empty images or pairs'
        super().__init__(**rng)
        self.imgs = image_set
        self.pairs = pairs
        self.trf = instanciate_transforms(trf, rng=self.rng)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        transform = self.trf or (lambda x: x)
        pair = tuple(map(transform, self._load_pair(idx)))
        return pair, {}

    def _load_pair(self, idx):
        i, j = self.pairs[idx]
        img1 = self.imgs.get_image(i)
        return (img1, img1) if i == j else (img1, self.imgs.get_image(j))

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)} pairs from {self.imgs})'


class StillImagePairs (ImagePairs):
    """ A dataset of 'still' image pairs, used for debugging purposes.
    """
    def __init__(self, image_set, pairs=None, **rng):
        if isinstance(image_set, ImagePairs):
            super().__init__(image_set.imgs, pairs or image_set.pairs, **rng)
        else:
            super().__init__(image_set, pairs or [(i,i) for i in range(len(image_set))], **rng)

    def __getitem__(self, idx):
        img1, img2 = self._load_pair(idx)
        sx, sy = img2.size / np.float32(img1.size)
        return (img1, img2), dict(homography=np.diag(np.float32([sx, sy, 1])))


class SyntheticImagePairs (StillImagePairs):
    """ A synthetic generator of image pairs.
    Given a normal image dataset, it constructs pairs using random homographies & noise.

    scale:   prior image scaling.
    distort: distortion applied independently to (img1,img2) if sym=True, else just to img2
    sym:     (bool) see above.
    """
    def __init__(self, image_set, scale='', distort='', sym=False, **rng):
        super().__init__(image_set, **rng)
        self.symmetric = sym
        self.scale = instanciate_transforms(scale, rng=self.rng)
        self.distort = instanciate_transforms(distort, rng=self.rng)

    def __getitem__(self, idx):
        (img1, img2), gt = super().__getitem__(idx)

        img1 = dict(img=img1, homography=np.eye(3, dtype=np.float32))
        if img1['img'] is img2:
            img1 = self.scale(img1)
            img2 = self.distort(dict(img1))
            if self.symmetric: img1 = self.distort(img1)
        else:
            if self.symmetric: img1 = self.distort(self.scale(img1))
            img2 = self.distort(self.scale(dict(img=img2, **gt)))

        return (img1['img'], img2['img']), dict(homography=img2['homography'] @ invh(img1['homography']))

    def __repr__(self):
        format = lambda s: ','.join(l.strip() for l in repr(s).splitlines() if l).replace(',','',1)
        return f"{self.__class__.__name__}({len(self)} images, scale={format(self.scale)}, distort={format(self.distort)})"


class CatImagePairs (DatasetWithRng):
    """ Concatenation of several ImagePairs datasets.
    """
    def __init__(self, *pair_datasets, seed=torch.initial_seed()):
        assert all(isinstance(db, ImagePairs) for db in pair_datasets)
        self.pair_datasets = pair_datasets
        DatasetWithRng.__init__(self, seed=seed)  # init last
        self._init()

    def _init(self):
        self._pair_offsets = np.cumsum([0] + [len(db) for db in self.pair_datasets])
        self.npairs = self._pair_offsets[-1]

    def __len__(self):
        return self.npairs

    def __repr__(self):
        fmt_str = f"{type(self).__name__}({len(self)} pairs,"
        for i, db in enumerate(self.pair_datasets):
            npairs = self._pair_offsets[i+1] - self._pair_offsets[i]
            fmt_str += f'\n\t{npairs} from ' + str(db).replace("\n"," ") + ','
        return fmt_str[:-1] + ')'

    def __getitem__(self, idx):
        b, i = self._which(idx)
        return self.pair_datasets[b].__getitem__(i)

    def _which(self, i):
        pos = np.searchsorted(self._pair_offsets, i, side='right') - 1
        assert pos < self.npairs, 'Bad pair index %d >= %d' % (i, self.npairs)
        return pos, i - self._pair_offsets[pos]

    def _call(self, func, i, *args, **kwargs):
        b, j = self._which(i)
        return getattr(self.pair_datasets[b], func)(j, *args, **kwargs)

    def init_worker(self, tid):
        for db in self.pair_datasets:
            db.init_worker(tid)


class BalancedCatImagePairs (CatImagePairs):
    """ Balanced concatenation of several ImagePairs datasets.
    """
    def __init__(self, npairs=0, *pair_datasets, **kw):
        assert isinstance(npairs, int) and npairs >= 0, 'BalancedCatImagePairs(npairs != int)'
        assert len(pair_datasets) > 0, 'no dataset provided'

        if len(pair_datasets) >= 3 and isinstance(pair_datasets[1], int):
            assert len(pair_datasets) % 2 == 1
            pair_datasets = [npairs] + list(pair_datasets)
            npairs, pair_datasets = pair_datasets[0::2], pair_datasets[1::2]
            assert all(isinstance(n, int) for n in npairs)
            self._pair_offsets = np.cumsum([0]+npairs)
            self.npairs = self._pair_offsets[-1]
        else:
            self.npairs = npairs or max(len(db) for db in pair_datasets)
            self._pair_offsets = np.linspace(0, self.npairs, len(pair_datasets)+1).astype(int)
        CatImagePairs.__init__(self, *pair_datasets, **kw)

    def set_epoch(self, epoch):
        DatasetWithRng.init_worker(self, epoch)  # the random seed only depends on the epoch
        self._init()  # reset the permutations for this epoch

    def init_worker(self, tid):
        CatImagePairs.init_worker(self, tid)

    def _init(self):
        self._perms = []
        for i, db in enumerate(self.pair_datasets):
            assert len(db), 'cannot balance if there is an empty dataset'
            avail = self._pair_offsets[i+1] - self._pair_offsets[i]
            idxs = np.arange(len(db))
            while len(idxs) < avail:
                idxs = np.r_[idxs, idxs]
            if self.seed:  # if there is no seed, then no shuffle
                self.rng.shuffle(idxs[(avail//len(db))*len(db):])
            self._perms.append(idxs[:avail])
        # print(self._perms)

    def _which(self, i):
        pos, idx = super()._which(i)
        return pos, self._perms[pos][idx]


class UnsupervisedPairs (ImagePairs):
    """ Unsupervised image pairs obtained from SfM.
    """
    def __init__(self, img_set, pair_file_path):
        assert isinstance(img_set, ImageSet), bb()
        self.pair_list = self._parse_pair_list(pair_file_path)
        self.corres_dir = osp.join(osp.split(pair_file_path)[0], 'corres')

        tag_to_idx = {n: i for i,n in enumerate(img_set.imgs)}
        img_indices = lambda pair: tuple([tag_to_idx[n] for n in pair])
        super().__init__(img_set, [img_indices(pair) for pair in self.pair_list])

    def __repr__(self):
        return f"{type(self).__name__}({len(self)} pairs from {self.imgs})"

    def _parse_pair_list(self, pair_file_path):
        res = []
        for row in open(pair_file_path).read().splitlines():
            row = row.split()
            if len(row) != 2: raise IOError()
            res.append((row[0], row[1]))
        return res

    def get_corres_path(self, pair_idx):
        img1, img2 = [osp.basename(self.imgs.imgs[i]) for i in self.pairs[pair_idx]]
        return osp.join(self.corres_dir, f'{img1}_{img2}.npy')

    def get_corres(self, pair_idx):
        return np.load(self.get_corres_path(pair_idx))

    def __getitem__(self, idx):
        img1, img2 = self._load_pair(idx)
        return (img1, img2), dict(corres=self.get_corres(idx))


if __name__ == '__main__':
    from datasets import *
    from tools.viz import show_random_pairs

    db = BalancedCatImagePairs(
        3125, SyntheticImagePairs(RandomWebImages(0,52), distort='RandomTilting(0.5)'),
        4875, SyntheticImagePairs(SfM120k_Images(), distort='RandomTilting(0.5)'),
        8000, SfM120k_Pairs())

    show_random_pairs(db)
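Besides the __main__ demo above, the simplest entry point is StillImagePairs, which turns any ImageSet into identity pairs; a minimal sketch (the demo folder is used for concreteness):

    from datasets.image_set import ImageSet

    imset = ImageSet.from_folder('datasets/demo_warp')
    db = StillImagePairs(imset)
    (img1, img2), gt = db[0]  # img1 is img2; gt['homography'] is the relative scaling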
datasets/pair_loader.py
ADDED
@@ -0,0 +1,291 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
from PIL import Image
import numpy as np

from core import functional as myF
from tools.common import todevice
from .transforms import instanciate_transforms
from .utils import *


class FastPairLoader (DatasetWithRng):
    """ On-the-fly generation of related image pairs.
    crop:    random crop applied to both images
    scale:   random scaling applied to img2
    distort: random distortion applied to img2

    self[idx] returns: (img1, img2), dict(homography=...)
    (homography: 3x3 array, can be nan)
    """
    def __init__(self, dataset, crop=256, transform='', p_flip=0, p_swap=0, scale_jitter=0, seed=None):
        super().__init__(seed)
        self.dataset = self.with_same_rng(dataset)
        self.transform = instanciate_transforms(transform, rng=self.rng)
        self.crop_size = crop
        self.p_swap = p_swap
        self.p_flip = p_flip
        self.scale_jitter = abs(np.log1p(scale_jitter))

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        fmt_str = f'FastPairLoader({self.dataset},\n'
        short_repr = lambda s: repr(s).strip().replace('\n',', ')[14:-1].replace('  ',' ')
        fmt_str += ' Transform:\t%s\n' % short_repr(self.transform)
        fmt_str += f' Crop={self.crop_size}, scale_jitter=x{np.exp(self.scale_jitter):g}, p_swap={self.p_swap:g}'
        return fmt_str

    def init_worker(self, tid):
        super().init_worker(tid)
        self.dataset.init_worker(tid)

    def set_epoch(self, epoch):
        self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        self.init_worker(idx)  # preserve the RNG for this pair
        (img1, img2), gt = self.dataset[idx]

        if self.rng.random() < self.p_swap:
            img1, img2 = img2, img1
            if 'homography' in gt: gt['homography'] = invh(gt['homography'])
            if 'corres' in gt: gt['corres'] = swap_corres(gt['corres'])

        if self.rng.random() < self.p_flip:
            img1, img2, gt = flip_image_pair(img1, img2, gt)

        # apply transformations to the second image
        img2 = self.transform(dict(img=img2))

        homography, corres = spatial_relationship(img1, img2, gt)

        # find a good window
        img1, img2 = map(self._pad_rgb_numpy, (img1, img2['img']))

        if not 'debug':
            from tools.viz import show_correspondences
            print(np.median(corres[:,5]))
            show_correspondences(img1, img2, corres, bb=bb)

        def windows_from_corres(idx, scale_jitter=1):
            c = corres[idx]
            p1, p2, scale = c[0:2], c[2:4], c[6]
            scale *= scale_jitter

            # make windows based on the scaling
            win1 = window(*p1, self.crop_size, max(1, 1/scale), img1.shape)
            win2 = window(*p2, self.crop_size, max(1, scale/1), img2.shape)
            return win1, win2

        best = 0, None
        for idx in self.rng.choice(len(corres), size=min(len(corres),5), replace=False):
            # pick a correspondence at random
            win1, win2 = windows_from_corres(idx)

            # check how many matches fall in the 2 windows
            score = score_windows(is_in(corres[:,0:2],win1), is_in(corres[:,2:4],win2))
            if score > best[0]: best = score, idx

        others = {}
        if None in best:  # couldn't find a good window
            img1 = img2 = np.zeros((self.crop_size,self.crop_size,3), dtype=np.uint8)
            corres = np.empty((0, 6), dtype=np.float32)
        else:
            # jitter the scales
            scale_jitter = np.exp(self.rng.uniform(-self.scale_jitter, self.scale_jitter))
            win1, win2 = windows_from_corres(best[1], scale_jitter)
            # print(win1, win2, img1.shape, img2.shape)
            img1, img2 = imresize(img1[win1], self.crop_size), imresize(img2[win2], self.crop_size)
            trf1, trf2 = wintrf(win1, img1), wintrf(win2, img2)

            # fix the rotation if necessary
            angle_scores = np.bincount(corres[:,5].astype(int) % 8)
            rot90 = int((((angle_scores.argmax() + 4) % 8) - 4) / 2)
            if rot90:  # rectify the rotation
                img2, trf = myF.rotate_img_90((img2, np.eye(3)), 90*rot90)
                trf2 = invh(trf) @ trf2

            homography = trf2 @ homography @ invh(trf1)
            corres = myF.affmul((trf1,trf2), corres)

        f32c = lambda i, **kw: np.require(i, requirements='CWAE', **kw)
        return (f32c(img1), f32c(img2)), dict(homography=f32c(homography, dtype=np.float32), corres=corres, **others)

    def _pad_rgb_numpy(self, img):
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if min(img.size) < self.crop_size:
            w, h = img.size
            result = Image.new('RGB', (max(w,self.crop_size), max(h,self.crop_size)), 0)
            result.paste(img, (0, 0))
            img = result
        return np.asarray(img)


def swap_corres(corres):  # swap img1 and img2
    res = corres.copy()
    res[:,[0,1,2,3]] = corres[:,[2,3,0,1]]
    if corres.shape[1] > 4:  # invert rotation and scale
        scale, rot = myF.decode_scale_rot(corres[:,5])
        res[:,5] = myF.encode_scale_rot(1/scale, -rot)
    return res

def flip(img):
    w, h = img.size
    return img.transpose(Image.FLIP_LEFT_RIGHT), np.float32([[-1,0,w-1],[0,1,0],[0,0,1]])

def flip_image_pair(img1, img2, gt):
    img1, F1 = flip(img1)
    img2, F2 = flip(img2)
    res = {}
    for key, value in gt.items():
        if key == 'homography':
            res['homography'] = F2 @ value @ F1
        elif key == 'aflow':
            assert False, 'flip for aflow: todo'
        elif key == 'corres':
            new_corres = np.c_[applyh(F1, value[:,0:2]), applyh(F2, value[:,2:4])]
            if value.shape[1] == 4: pass
            elif value.shape[1] == 6:
                scale, rot = myF.decode_scale_rot(value[:,5])
                new_code = myF.encode_scale_rot(scale, -rot)
                new_corres = np.c_[new_corres, value[:,4], new_code]
            res['corres'] = new_corres
        else:
            raise ValueError(f"flip_image_pair: bad gt field '{key}'")
    return img1, img2, res


def spatial_relationship(img1, img2, gt):
    if 'homography' in gt:
        homography = gt['homography']
        if 'homography' in img2:
            homography = np.float32(img2['homography']) @ homography
        corres = corres_from_homography(homography, *img1.size)

    elif 'corres' in gt:
        homography = np.full((3,3), np.nan, dtype=np.float32)
        corres = gt['corres']
        if 'homography' in img2:
            corres[:,2:4] = applyh(img2['homography'], corres[:,2:4])
        else:
            img2['homography'] = np.eye(3)
        scales = np.sqrt(np.abs(np.linalg.det(jacobianh(img2['homography'], corres[:,0:2]).T)))

        if corres.shape[1] == 4:
            scales, rots = scale_rot_from_corres(corres)
            corres = np.c_[corres, np.ones_like(scales), myF.encode_scale_rot(scales, rots*180/np.pi), scales]
        elif corres.shape[1] == 6:
            corres = np.c_[corres, scales * myF.decode_scale_rot(corres[:,5])[0]]
        else:
            raise ValueError(f'bad shape for corres: {corres.shape}')

    return homography, corres


def scale_rot_from_corres(corres, sub=256, nn=16):
    # select a subset of relevant correspondences
    sub = np.random.choice(len(corres), size=min(len(corres),sub), replace=False)
    sub = corres[sub]

    # for each corres, find the scale change w.r.t. its NNs
    from scipy.spatial.distance import cdist
    nns = cdist(corres, sub, metric='sqeuclidean').argsort(axis=1)[:,:nn]

    # affine transform for this set of neighboring correspondences
    pts = sub[nns]  # shape = npts x sub x 4
    # [P1,1] @ A = P2 with A = 3x2 matrix
    # A = [P1,1]^-1 @ P2
    P1, P2 = pts[:,:,0:2], pts[:,:,2:4]  # each row = list of correspondences
    P1 = np.concatenate((P1, np.ones_like(P1[:,:,:1])), axis=-1)
    A = (np.linalg.pinv(P1) @ P2).transpose(0,2,1)

    scale, (angy, angx) = detect_scale_rotation(A.transpose(1,2,0)[:,1::-1])
    rot = np.arctan2(angy, angx)
    return scale.clip(min=0.2, max=5), rot


def window1(x, size, w):
    l = x - int(0.5 + size / 2)
    r = l + int(0.5 + size)
    if l < 0: l,r = (0, r - l)
    if r > w: l,r = (l + w - r, w)
    if l < 0: l,r = 0,w  # larger than the width
    return slice(l,r)

def window(cx, cy, win_size, scale, img_shape):
    return (window1(int(cy), win_size*scale, img_shape[0]),
            window1(int(cx), win_size*scale, img_shape[1]))

def is_in(pts, window):
    x, y = pts.T
    sly, slx = window
    return (slx.start <= x) & (x < slx.stop) & (sly.start <= y) & (y < sly.stop)

def score_windows(valid1, valid2):
    inter = (valid1 & valid2).sum()
    iou1 = inter / (valid1.sum() + 1e-8)
    iou2 = inter / (valid2.sum() + 1e-8)
    return inter * min(iou1, iou2)

def imresize(img, max_size, resample=Image.ANTIALIAS):
    if max(img.shape[:2]) > max_size:
        if img.shape[-1] == 2:
            img = np.stack([np.float32(Image.fromarray(img[...,i]).resize((max_size,max_size), resample=resample)) for i in range(2)], axis=-1)
        else:
            img = np.asarray(Image.fromarray(img).resize((max_size,max_size), resample=resample))
    assert img.shape[0] == img.shape[1] == max_size, bb()
    return img

def wintrf(window, final_img):
    wy, wx = window
    H, W = final_img.shape[:2]
    T = np.float32((((wx.stop-wx.start)/W, 0, wx.start),
                    (0, (wy.stop-wy.start)/H, wy.start),
                    (0, 0, 1)))
    return invh(T)


def collate_ordered(batch, _use_shared_memory=True):
    pairs, gt = zip(*batch)
    imgs1, imgs2 = zip(*pairs)
    assert len(imgs1) == len(imgs2) == len(gt) and isinstance(gt[0], dict)

    # reorder samples (supervised ones first, unsupervised ones last)
    supervised = [i for i,b in enumerate(gt) if np.isfinite(b['homography']).all()]
    unsupervsd = [i for i,b in enumerate(gt) if np.isnan(b['homography']).any()]
    order = supervised + unsupervsd

    def collate(tensors, key=None):
        import torch
        batch = todevice([tensors[i] for i in order], 'cpu')
        if key == 'corres': return batch  # cannot concat
        if _use_shared_memory:  # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, dim=0, out=out)

    return (collate(imgs1), collate(imgs2)), {k: collate([b[k] for b in gt], k) for k in gt[0]}


if __name__ == '__main__':
    from datasets import *
    from tools.viz import show_random_pairs

    db = BalancedCatImagePairs(
        3125, SyntheticImagePairs(RandomWebImages(0,52), distort='RandomTilting(0.5)'),
        4875, SyntheticImagePairs(SfM120k_Images(), distort='RandomTilting(0.5)'),
        8000, SfM120k_Pairs())

    db = FastPairLoader(db,
        crop=256, transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise()',
        p_swap=0.5, p_flip=0.5, scale_jitter=0, seed=777)

    show_random_pairs(db)
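The window helpers are easy to sanity-check in isolation; the expected values below follow directly from the clamping logic of window1 and can serve as a quick self-test:

    assert window1(5, 4, 100) == slice(3, 7)      # fully inside
    assert window1(1, 4, 100) == slice(0, 4)      # shifted right at the left border
    assert window1(99, 4, 100) == slice(96, 100)  # shifted left at the right border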
datasets/sfm120k.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
from os.path import *

from .image_set import ImageSet
from .pair_dataset import UnsupervisedPairs


class SfM120k_Images (ImageSet):
    def __init__(self, root='datasets/sfm120k'):
        self.init_from_folder(join(root,'ims'), recursive=True, listing=True, exts='')


class SfM120k_Pairs (UnsupervisedPairs):
    def __init__(self, root='datasets/sfm120k'):
        super().__init__(SfM120k_Images(root=root), join(root,'list_pairs.txt'))


if __name__ == '__main__':
    from tools.viz import show_random_pairs

    db = SfM120k_Pairs()

    show_random_pairs(db)
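For reference (format inferred from _parse_pair_list in pair_dataset.py; the file names below are hypothetical): datasets/sfm120k/list_pairs.txt must contain exactly two whitespace-separated image tags per line, each tag being an image path as indexed by SfM120k_Images, e.g.:

    00/aaa.jpg  00/bbb.jpg
    01/ccc.jpg  00/aaa.jpg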
datasets/transforms.py
ADDED
@@ -0,0 +1,540 @@
1 |
+
# Copyright 2022-present NAVER Corp.
|
2 |
+
# CC BY-NC-SA 4.0
|
3 |
+
# Available only for non-commercial use
|
4 |
+
|
5 |
+
from pdb import set_trace as bb
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image, ImageOps
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torchvision import transforms as tvf
|
14 |
+
|
15 |
+
from . import transforms_tools as F
|
16 |
+
from .utils import DatasetWithRng
|
17 |
+
|
18 |
+
'''
|
19 |
+
Example command to try out some transformation chain:
|
20 |
+
|
21 |
+
python -m pytools.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(224)"
|
22 |
+
'''
|
23 |
+
|
24 |
+
def instanciate_transforms(transforms, use_gpu=False, rng=None, compose=True):
|
25 |
+
''' Instanciate a sequence of transformations.
|
26 |
+
|
27 |
+
transforms: (str, list)
|
28 |
+
Comma-separated list of transformations.
|
29 |
+
Ex: "Rotate(10), Scale(256)"
|
30 |
+
'''
|
31 |
+
try:
|
32 |
+
transforms = transforms or '[]'
|
33 |
+
|
34 |
+
if isinstance(transforms, str):
|
35 |
+
if transforms.lstrip()[0] not in '[(': transforms = f'[{transforms}]'
|
36 |
+
if compose: transforms = f'Compose({transforms})'
|
37 |
+
transforms = eval(transforms)
|
38 |
+
|
39 |
+
if isinstance(transforms, list) and transforms and isinstance(transforms[0], str):
|
40 |
+
transforms = [eval(trf) for trf in transforms]
|
41 |
+
if compose: transforms = Compose(transforms)
|
42 |
+
|
43 |
+
if use_gpu and not isinstance(transforms, nn.Module):
|
44 |
+
while hasattr(transforms,'transforms') or hasattr(transforms,'transform'):
|
45 |
+
transforms = getattr(transforms,'transforms',getattr(transforms,'transform',None))
|
46 |
+
transforms = [trf for trf in transforms if isinstance(trf, nn.Module)]
|
47 |
+
transforms = nn.Sequential(*transforms) if compose else nn.ModuleList(transforms)
|
48 |
+
|
49 |
+
if transforms and rng:
|
50 |
+
for trf in transforms.transforms:
|
51 |
+
assert hasattr(trf, 'rng'), f"Transformation {trf} has no self.rng"
|
52 |
+
trf.rng = rng
|
53 |
+
|
54 |
+
if isinstance(transforms, Compose) and len(transforms.transforms) == 1:
|
55 |
+
transforms = transforms.transforms[0]
|
56 |
+
return transforms
|
57 |
+
|
58 |
+
except Exception as e:
|
59 |
+
print("\nError: Cannot interpret this transform list: %s\n" % transforms)
|
60 |
+
raise e
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
class Compose (DatasetWithRng):
|
65 |
+
def __init__(self, transforms, **rng_seed):
|
66 |
+
super().__init__(**rng_seed)
|
67 |
+
self.transforms = [self.with_same_rng(trf) for trf in transforms]
|
68 |
+
|
69 |
+
def __call__(self, data):
|
70 |
+
for trf in self.transforms:
|
71 |
+
data = trf(data)
|
72 |
+
return data
|
73 |
+
|
74 |
+
|
75 |
+
class Scale (DatasetWithRng):
|
76 |
+
""" Rescale the input PIL.Image to a given size.
|
77 |
+
Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
|
78 |
+
|
79 |
+
The smallest dimension of the resulting image will be = size.
|
80 |
+
|
81 |
+
if largest == True: same behaviour for the largest dimension.
|
82 |
+
|
83 |
+
if not can_upscale: don't upscale
|
84 |
+
if not can_downscale: don't downscale
|
85 |
+
"""
|
86 |
+
def __init__(self, size, interpolation=Image.BILINEAR, largest=False,
|
87 |
+
can_upscale=True, can_downscale=True, **rng_seed):
|
88 |
+
super().__init__(**rng_seed)
|
89 |
+
assert isinstance(size, int) or (len(size) == 2)
|
90 |
+
self.size = size
|
91 |
+
self.interpolation = interpolation
|
92 |
+
self.largest = largest
|
93 |
+
self.can_upscale = can_upscale
|
94 |
+
self.can_downscale = can_downscale
|
95 |
+
|
96 |
+
def __repr__(self):
|
97 |
+
fmt_str = "RandomScale(%s" % str(self.size)
|
98 |
+
if self.largest: fmt_str += ', largest=True'
|
99 |
+
if not self.can_upscale: fmt_str += ', can_upscale=False'
|
100 |
+
if not self.can_downscale: fmt_str += ', can_downscale=False'
|
101 |
+
return fmt_str+')'
|
102 |
+
|
103 |
+
def get_params(self, imsize):
|
104 |
+
w,h = imsize
|
105 |
+
if isinstance(self.size, int):
|
106 |
+
cmp = lambda a,b: (a>=b) if self.largest else (a<=b)
|
107 |
+
if (cmp(w, h) and w == self.size) or (cmp(h, w) and h == self.size):
|
108 |
+
ow, oh = w, h
|
109 |
+
elif cmp(w, h):
|
110 |
+
ow = self.size
|
111 |
+
oh = int(self.size * h / w)
|
112 |
+
else:
|
113 |
+
oh = self.size
|
114 |
+
ow = int(self.size * w / h)
|
115 |
+
else:
|
116 |
+
ow, oh = self.size
|
117 |
+
return ow, oh
|
118 |
+
|
119 |
+
def __call__(self, inp):
|
120 |
+
img = F.grab(inp,'img')
|
121 |
+
w, h = img.size
|
122 |
+
|
123 |
+
size2 = ow, oh = self.get_params(img.size)
|
124 |
+
|
125 |
+
if size2 != img.size:
|
126 |
+
a1, a2 = img.size, size2
|
127 |
+
if (self.can_upscale and min(a1) < min(a2)) or (self.can_downscale and min(a1) > min(a2)):
|
128 |
+
img = img.resize(size2, self.interpolation)
|
129 |
+
|
130 |
+
return F.update(inp, img=img, homography=np.diag((ow/w,oh/h,1)))
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
class RandomScale (Scale):
    """Rescale the input PIL.Image to a random size.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        min_size (int): min size of the smaller edge of the picture.
        max_size (int): max size of the smaller edge of the picture.

        ar (float or tuple):
            max change of aspect ratio (width/height).

        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, min_size, max_size, ar=1, larger=False,
                       can_upscale=False, can_downscale=True, interpolation=Image.BILINEAR):
        Scale.__init__(self, (min_size, max_size), can_upscale=can_upscale, can_downscale=can_downscale, interpolation=interpolation)
        assert type(min_size) == type(max_size), 'min_size and max_size can only be 2 ints or 2 floats'
        assert isinstance(min_size, int) and min_size >= 1 or isinstance(min_size, float) and min_size > 0
        assert isinstance(max_size, (int, float)) and min_size <= max_size
        self.min_size = min_size
        self.max_size = max_size
        if type(ar) in (float, int): ar = (min(1/ar, ar), max(1/ar, ar))
        assert 0.2 < ar[0] <= ar[1] < 5
        self.ar = ar
        self.larger = larger

    def get_params(self, imsize):
        w, h = imsize
        if isinstance(self.min_size, float): min_size = int(self.min_size*min(w, h) + 0.5)
        if isinstance(self.max_size, float): max_size = int(self.max_size*min(w, h) + 0.5)
        if isinstance(self.min_size, int): min_size = self.min_size
        if isinstance(self.max_size, int): max_size = self.max_size

        if not(self.can_upscale) and not(self.larger):
            max_size = min(max_size, min(w, h))

        size = int(0.5 + F.rand_log_uniform(self.rng, min_size, max_size))
        if not(self.can_upscale) and self.larger:
            size = min(size, min(w, h))

        ar = F.rand_log_uniform(self.rng, *self.ar)  # change of aspect ratio

        if w < h:  # image is taller
            ow = size
            oh = int(0.5 + size * h / w / ar)
            if oh < min_size:
                ow, oh = int(0.5 + ow*float(min_size)/oh), min_size
        else:  # image is wider
            oh = size
            ow = int(0.5 + size * w / h * ar)
            if ow < min_size:
                ow, oh = min_size, int(0.5 + oh*float(min_size)/ow)

        assert ow >= min_size, 'image too small (width=%d < min_size=%d)' % (ow, min_size)
        assert oh >= min_size, 'image too small (height=%d < min_size=%d)' % (oh, min_size)
        return ow, oh

class RandomCrop (DatasetWithRng):
    """Crop the given PIL Image at a random location.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0, **rng_seed):
        super().__init__(**rng_seed)
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __repr__(self):
        return "RandomCrop(%s)" % str(self.size)

    def get_params(self, img, output_size):
        w, h = img.size
        th, tw = output_size
        assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (w, h, tw, th)

        y = self.rng.integers(0, h - th) if h > th else 0
        x = self.rng.integers(0, w - tw) if w > tw else 0
        return x, y, tw, th

    def __call__(self, inp):
        img = F.grab(inp, 'img')

        padl = padt = 0
        if self.padding:
            if F.is_pil_image(img):
                img = ImageOps.expand(img, border=self.padding, fill=0)
            else:
                assert isinstance(img, F.DummyImg)
                img = img.expand(border=self.padding)
            if isinstance(self.padding, int):
                padl = padt = self.padding
            else:
                padl, padt = self.padding[0:2]

        i, j, tw, th = self.get_params(img, self.size)
        img = img.crop((i, j, i+tw, j+th))

        return F.update(inp, img=img, homography=np.float32(((1, 0, padl-i), (0, 1, padt-j), (0, 0, 1))))

class CenterCrop (RandomCrop):
    """Crops the given PIL Image at the center.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """
    @staticmethod
    def get_params(img, output_size):
        w, h = img.size
        th, tw = output_size
        y = int(0.5 + ((h - th) / 2.))
        x = int(0.5 + ((w - tw) / 2.))
        return x, y, tw, th

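# --- Illustrative usage sketch (not part of the original file) ---------------
# Each transform above returns F.update(...), which left-multiplies its own
# homography onto data['homography']; the accumulated matrix therefore maps
# original-image pixel coordinates to coordinates in the transformed image.
# A minimal, hedged example (assumes 'imgs/test.png' exists, as elsewhere in
# this repo):
#
#   data = dict(img=pil_loader('imgs/test.png'), homography=np.eye(3))
#   data = Scale(256)(data)
#   data = CenterCrop(128)(data)
#   p = data['homography'] @ (10., 20., 1.)   # original pixel (10, 20)
#   print(p[:2] / p[2])                       # its position in the 128x128 crop
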
class RandomRotation (DatasetWithRng):
    """Rotate the input PIL.Image by a random angle.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        degrees (float):
            rotation angle.

        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, degrees, interpolation=Image.BILINEAR, **rng_seed):
        super().__init__(**rng_seed)
        self.degrees = degrees
        self.interpolation = interpolation

    def __repr__(self):
        return f"RandomRotation({self.degrees})"

    def __call__(self, inp):
        img = F.grab(inp, 'img')
        w, h = img.size

        angle = self.rng.uniform(-self.degrees, self.degrees)

        img = img.rotate(angle, resample=self.interpolation)
        w2, h2 = img.size

        trf = F.translate(w2/2, h2/2) @ F.rotate(-angle * np.pi/180) @ F.translate(-w/2, -h/2)
        return F.update(inp, img=img, homography=trf)

class RandomTilting (DatasetWithRng):
    """Apply a random tilting (left, right, up, down) to the input PIL.Image
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        magnitude (float):
            maximum magnitude of the random skew (value between 0 and 1)
        directions (string):
            tilting directions allowed (all, left, right, up, down)
            examples: "all", "left,right", "up-down-right"
    """

    def __init__(self, magnitude, directions='all', **rng_seed):
        super().__init__(**rng_seed)
        self.magnitude = magnitude
        self.directions = directions.lower().replace(',', ' ').replace('-', ' ')

    def __repr__(self):
        return "RandomTilt(%g, '%s')" % (self.magnitude, self.directions)

    def __call__(self, inp):
        img = F.grab(inp, 'img')
        w, h = img.size

        x1, y1, x2, y2 = 0, 0, h, w
        original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]

        max_skew_amount = max(w, h)
        max_skew_amount = int(np.ceil(max_skew_amount * self.magnitude))
        skew_amount = self.rng.integers(1, max_skew_amount)

        if self.directions == 'all':
            choices = [0, 1, 2, 3]
        else:
            dirs = ['left', 'right', 'up', 'down']
            choices = []
            for d in self.directions.split():
                try:
                    choices.append(dirs.index(d))
                except ValueError:
                    raise ValueError('Tilting direction %s not recognized' % d)

        skew_direction = self.rng.choice(choices)

        # print('randomtilting: ', skew_amount, skew_direction) # to debug random

        if skew_direction == 0:
            # Left Tilt
            new_plane = [(y1, x1 - skew_amount),  # Top Left
                         (y2, x1),                # Top Right
                         (y2, x2),                # Bottom Right
                         (y1, x2 + skew_amount)]  # Bottom Left
        elif skew_direction == 1:
            # Right Tilt
            new_plane = [(y1, x1),                # Top Left
                         (y2, x1 - skew_amount),  # Top Right
                         (y2, x2 + skew_amount),  # Bottom Right
                         (y1, x2)]                # Bottom Left
        elif skew_direction == 2:
            # Forward Tilt
            new_plane = [(y1 - skew_amount, x1),  # Top Left
                         (y2 + skew_amount, x1),  # Top Right
                         (y2, x2),                # Bottom Right
                         (y1, x2)]                # Bottom Left
        elif skew_direction == 3:
            # Backward Tilt
            new_plane = [(y1, x1),                # Top Left
                         (y2, x1),                # Top Right
                         (y2 + skew_amount, x2),  # Bottom Right
                         (y1 - skew_amount, x2)]  # Bottom Left

        # To calculate the coefficients required by PIL for the perspective skew,
        # see the following Stack Overflow discussion: https://goo.gl/sSgJdj
        homography = F.homography_from_4pts(original_plane, new_plane)
        img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC)

        homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3, 3))
        return F.update(inp, img=img, homography=homography)

RandomHomography = RandomTilt = RandomTilting  # redefinition


class Homography(object):
    """Apply a known tilting to an image
    """
    def __init__(self, *homography):
        assert len(homography) == 8
        self.homography = homography

    def __call__(self, inp):
        img = F.grab(inp, 'img')
        homography = self.homography

        img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC)

        homography = np.linalg.pinv(np.float32(list(homography)+[1]).reshape(3, 3))
        return F.update(inp, img=img, homography=homography)

class StillTransform (DatasetWithRng):
    """ Takes and returns an image, without changing its shape or geometry.
    """
    def _transform(self, img):
        raise NotImplementedError()

    def __call__(self, inp):
        img = F.grab(inp, 'img')

        # transform the image (size should not change)
        try:
            img = self._transform(img)
        except TypeError:
            pass

        return F.update(inp, img=img)

class PixelNoise (StillTransform):
    """ Takes an image and adds random white noise.
    """
    def __init__(self, ampl=20, **rng_seed):
        super().__init__(**rng_seed)
        assert 0 <= ampl < 255
        self.ampl = ampl

    def __repr__(self):
        return "PixelNoise(%g)" % self.ampl

    def _transform(self, img):
        img = np.float32(img)
        img += self.rng.uniform(0.5 - self.ampl/2, 0.5 + self.ampl/2, size=img.shape)
        return Image.fromarray(np.uint8(img.clip(0, 255)))

class ColorJitter (StillTransform):
    """Randomly change the brightness, contrast and saturation of an image.
    Copied from https://github.com/pytorch in torchvision/transforms/transforms.py

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >= 0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()  # initialize the DatasetWithRng rng, like the other transforms
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __repr__(self):
        return "ColorJitter(%g,%g,%g,%g)" % (
            self.brightness, self.contrast, self.saturation, self.hue)

    def get_params(self, brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.
        Arguments are same as that of __init__.
        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = self.rng.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = self.rng.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = self.rng.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = self.rng.uniform(-hue, hue)
            transforms.append(tvf.Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        # print('colorjitter: ', brightness_factor, contrast_factor, saturation_factor, hue_factor) # to debug random seed
        self.rng.shuffle(transforms)
        transform = tvf.Compose(transforms)
        return transform

    def _transform(self, img):
        transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
        return transform(img)

def pil_loader(path, mode='RGB'):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with (path if hasattr(path, 'read') else open(path, 'rb')) as f:
            img = Image.open(f)
            return img.convert(mode)

def torchvision_loader(path, mode='RGB'):
    from torchvision.io import read_file, decode_image, read_image, image
    return read_image(getattr(path, 'name', path), mode=getattr(image.ImageReadMode, mode))

if __name__ == '__main__':
    from matplotlib import pyplot as pl
    from time import time as now  # assumed alias: 'now' is not defined elsewhere in this listing
    import argparse

    parser = argparse.ArgumentParser("Script to try out and visualize transformations")
    parser.add_argument('--img', type=str, default='imgs/test.png', help='input image')
    parser.add_argument('--trfs', type=str, required=True, help='list of transformations')
    parser.add_argument('--layout', type=int, nargs=2, default=(3, 3), help='nb of rows,cols')
    args = parser.parse_args()

    img = dict(img=pil_loader(args.img))

    trfs = instanciate_transforms(args.trfs)

    pl.subplots_adjust(0, 0, 1, 1)
    nr, nc = args.layout

    while True:
        t0 = now()
        imgs2 = [trfs(img) for _ in range(nr*nc)]

        for j in range(nr):
            for i in range(nc):
                pl.subplot(nr, nc, i + j*nc + 1)
                img2 = img if i == j == 0 else imgs2.pop()  # trfs(img)
                img2 = img2['img']
                pl.imshow(img2)
                pl.xlabel("%d x %d" % img2.size)
        print(f'Took {now() - t0:.2f} seconds')
        pl.show()
datasets/transforms_tools.py
ADDED
@@ -0,0 +1,71 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
from PIL import Image, ImageOps, ImageEnhance


def grab(data, *fields):
    ''' Called to extract fields from a dictionary
    '''
    if isinstance(data, dict):
        res = []
        for f in fields:
            res.append(data[f])
        return res[0] if len(fields) == 1 else tuple(res)

    else:  # or it must be the img directly
        assert fields == ('img',) and isinstance(data, (np.ndarray, Image.Image)), \
            f"data should be an image, not {type(data)}!"
        return data


def update(data, **fields):
    ''' Called to update the img_and_label
    '''
    if isinstance(data, dict):
        if 'homography' in fields and 'homography' in data:
            data['homography'] = fields.pop('homography') @ data['homography']
        data.update(fields)
        if 'img' in fields:
            data['imsize'] = data['img'].size
        return data

    else:  # or it must be the img directly
        return fields['img']


def rand_log_uniform(rng, a, b):
    return np.exp(rng.uniform(np.log(a), np.log(b)))


def translate(tx, ty):
    return np.float32(((1, 0, tx), (0, 1, ty), (0, 0, 1)))

def rotate(angle):
    return np.float32(((np.cos(angle), -np.sin(angle), 0), (np.sin(angle), np.cos(angle), 0), (0, 0, 1)))


def is_pil_image(img):
    return isinstance(img, Image.Image)


def homography_from_4pts(pts_cur, pts_new):
    "pts_cur and pts_new = 4x2 point array, in [(x,y),...] format"
    matrix = []
    for p1, p2 in zip(pts_new, pts_cur):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
    A = np.matrix(matrix, dtype=float)
    B = np.array(pts_cur).reshape(8)

    homography = np.dot(np.linalg.pinv(A), B)
    homography = tuple(np.array(homography).reshape(8))
    # print(homography)
    return homography

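# --- Illustrative sanity check (appended sketch, not in the original file) ---
# The 8 coefficients returned by homography_from_4pts() follow PIL's
# Image.PERSPECTIVE convention: they map a point of the *new* plane back onto
# the *current* plane.
def _check_homography_from_4pts():
    cur = [(0, 0), (100, 0), (100, 50), (0, 50)]
    new = [(10, 5), (90, -5), (110, 60), (-10, 45)]
    a, b, c, d, e, f, g, h = homography_from_4pts(cur, new)
    for (xc, yc), (xn, yn) in zip(cur, new):
        denom = g*xn + h*yn + 1
        assert abs((a*xn + b*yn + c) / denom - xc) < 1e-6
        assert abs((d*xn + e*yn + f) / denom - yc) < 1e-6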
datasets/utils.py
ADDED
@@ -0,0 +1,104 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
import torch


class DatasetWithRng:
    """ Make sure that RNG is distributed properly when torch.dataloader() is used
    """

    def __init__(self, seed=None):
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        self._rng_children = set()

    def with_same_rng(self, dataset=None):
        if dataset is not None:
            assert isinstance(dataset, DatasetWithRng) and hasattr(dataset, 'rng'), bb()
            self._rng_children.add(dataset)

        # update all registered children
        for db in self._rng_children:
            db.rng = self.rng
            db.with_same_rng()  # recursive call
        return dataset

    def init_worker(self, tid):
        if self.seed is None:
            self.rng = np.random.default_rng()
        else:
            self.rng = np.random.default_rng(self.seed + tid)


class WorkerWithRngInit:
    " Dataset inherits from datasets.DatasetWithRng() and has an init_worker() function "
    def __call__(self, tid):
        torch.utils.data.get_worker_info().dataset.init_worker(tid)

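# --- Illustrative wiring sketch (not in the original file) -------------------
# How the two classes above plug into a torch DataLoader: each worker re-seeds
# its own Generator via init_worker(tid). 'ToyRngDataset' is a made-up name.
#
#   class ToyRngDataset (DatasetWithRng, torch.utils.data.Dataset):
#       def __len__(self): return 8
#       def __getitem__(self, i): return int(self.rng.integers(0, 100))
#
#   loader = torch.utils.data.DataLoader(ToyRngDataset(seed=777), num_workers=2,
#                                        worker_init_fn=WorkerWithRngInit())
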
def corres_from_homography(homography, W, H, grid=64):
    s = max(1, min(W, H) // grid)  # at least `grid` points in smallest dim
    sx, sy = [slice(s//2, l, s) for l in (W, H)]
    grid1 = np.mgrid[sy, sx][::-1].reshape(2, -1).T  # (x1,y1) grid

    grid2 = applyh(homography, grid1)
    scale = np.sqrt(np.abs(np.linalg.det(jacobianh(homography, grid1).T)))

    corres = np.c_[grid1, grid2, np.ones_like(scale), np.zeros_like(scale), scale]
    return corres


def invh(H):
    return np.linalg.inv(H)


def applyh(H, p, ncol=2, norm=True):
    """ Apply the homography to a list of 2d points in homogeneous coordinates.

    H: Homography (...x3x3 matrix/tensor)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    Returns an array of projected 2d points.
    """
    if isinstance(H, np.ndarray):
        p = np.asarray(p)
    elif isinstance(H, torch.Tensor):
        p = torch.as_tensor(p, dtype=H.dtype)

    if p.shape[-1]+1 == H.shape[-1]:
        H = H.swapaxes(-1, -2)  # transpose H
        p = p @ H[..., :-1, :] + H[..., -1:, :]
    else:
        p = H @ p.T
        if p.ndim >= 2: p = p.swapaxes(-1, -2)

    if norm:
        p /= p[..., -1:]
    return p[..., :ncol]


def jacobianh(H, p):
    """ H is an homography that maps: f_H(x,y) --> (f_1, f_2)
    So the Jacobian J_H evaluated at p=(x,y) is a 2x2 matrix
    Output shape = (2, 2, N) = (f_, xy, N)

    Example of derivative:
                  numx    a*X + b*Y + c*Z
        since x = ----- = ---------------
                  denom   u*X + v*Y + w*Z

                numx' * denom - denom' * numx   a*denom - u*numx
        dx/dX = ----------------------------- = ----------------
                           denom**2                 denom**2
    """
    (a, b, c), (d, e, f), (u, v, w) = H
    numx, numy, denom = applyh(H, p, ncol=3, norm=False).T

    #                   column x           column y
    J = np.float32(((a*denom - u*numx, b*denom - v*numx),   # row f_1
                    (d*denom - u*numy, e*denom - v*numy)))  # row f_2
    return J / np.where(denom, denom*denom, np.nan)
datasets/web_images.py
ADDED
@@ -0,0 +1,50 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os, os.path as osp

from tqdm import trange
from .image_set import ImageSet, verify_img


class RandomWebImages (ImageSet):
    """ 1 million distractors from Oxford and Paris Revisited
        see http://ptak.felk.cvut.cz/revisitop/revisitop1m/
    """
    def __init__(self, start=0, end=52, root="datasets/revisitop1m"):
        bar = None
        imgs = []
        for i in range(start, end):
            try:
                # read cached list
                img_list_path = osp.join(root, "image_list_%d.txt" % i)
                cached_imgs = [e.strip() for e in open(img_list_path)]
                assert cached_imgs, f"Cache '{img_list_path}' is empty!"
                imgs += cached_imgs

            except IOError:
                if bar is None:
                    bar = trange(start, 4*end, desc='Caching')
                bar.update(4*i)

                # create the cache; use a separate list so the images
                # accumulated so far are not overwritten
                new_imgs = []
                for d in range(i*4, (i+1)*4):  # 4096 folders in total, on average 256 each
                    key = hex(d)[2:].zfill(3)
                    folder = osp.join(root, key)
                    if not osp.isdir(folder): continue
                    new_imgs += [f for f in os.listdir(folder) if verify_img(osp.join(folder, f), exts='.jpg')]
                    bar.update(1)
                assert new_imgs, f"No images found in {folder}/"
                open(img_list_path, 'w').write('\n'.join(new_imgs))
                imgs += new_imgs

        if bar: bar.update(bar.total - bar.n)
        super().__init__(root, imgs)

    def get_image_path(self, idx):
        key = self.imgs[idx]
        return osp.join(self.root, key[:3], key)
demo_warping.py
ADDED
@@ -0,0 +1,102 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os, os.path as osp

from PIL import Image
import numpy as np
from tools.viz import pl, noticks

""" This script will warp (deform) img2 so that it fits img1

>> In case of memory failure (not enough GPU memory):
   try adding '--resize 400 300' (or larger values if possible) to the _exec(...) command below.
"""

def parse_args():
    import argparse
    parser = argparse.ArgumentParser('PUMP demo script for the image warping demo')

    parser.add_argument('--img1', default='datasets/demo_warp/mountains_src.jpg')
    parser.add_argument('--img2', default='datasets/demo_warp/mountains_tgt.jpg')
    parser.add_argument('--output', default='results/demo_warp')

    parser.add_argument('--just-print', action='store_true', help='just print commands')
    return parser.parse_args()


def main(args):
    run_pump(args) and run_demo_warp(args)


def run_pump(args):
    output_path = osp.join(args.output, args.img1, args.img2+'.corres')
    if osp.isfile(output_path): return True

    return _exec(f'''python test_singlescale_recursive.py
        --img1 {args.img1}
        --img2 {args.img2}
        --post-filter densify=True
        --output {output_path}''')


def run_demo_warp(args):
    corres_path = osp.join(args.output, args.img1, args.img2+'.corres')
    corres = np.load(corres_path)['corres']

    img1 = Image.open(args.img1).convert('RGB')
    img2 = Image.open(args.img2).convert('RGB')

    W, H = img1.size
    warped_img2 = warp_img(np.asarray(img2), corres[:, 2:4].reshape(H, W, 2))

    pl.figure('Warping demo')

    noticks(pl.subplot(211))
    pl.imshow(img2)
    pl.title('Source image')

    noticks(pl.subplot(223))
    pl.imshow(img1)
    pl.title('Target image')

    noticks(pl.subplot(224))
    pl.imshow(warped_img2)
    pl.title('Source image warped to match target')

    pl.tight_layout()
    pl.show(block=True)


def warp_img(img, absolute_flow):
    H1, W1, TWO = absolute_flow.shape
    H2, W2, THREE = img.shape
    assert TWO == 2 and THREE == 3

    warp = absolute_flow.round().astype(int)
    invalid = (warp[:,:,0] < 0) | (warp[:,:,0] >= W2) | (warp[:,:,1] < 0) | (warp[:,:,1] >= H2)

    warp[:,:,0] = warp[:,:,0].clip(min=0, max=W2-1)
    warp[:,:,1] = warp[:,:,1].clip(min=0, max=H2-1)
    warp = warp[:,:,0] + W2*warp[:,:,1]

    warped_img = np.asarray(img).reshape(-1, 3)[warp].reshape(H1, W1, 3)
    return warped_img


def _exec(cmd):
    # strip & remove \n
    cmd = ' '.join(cmd.split())

    if args.just_print:
        print(cmd)
        return False
    else:
        return os.WEXITSTATUS(os.system(cmd)) == 0

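# --- Illustrative self-test (appended sketch, not in the original script) ----
# With the identity "absolute flow" (each target pixel (x, y) samples source
# pixel (x, y)), warp_img() must return the image unchanged.
def _selftest_warp():
    img = np.random.randint(0, 255, (4, 5, 3), dtype=np.uint8)
    ys, xs = np.mgrid[0:4, 0:5]
    flow = np.stack([xs, ys], axis=-1).astype(np.float32)
    assert (warp_img(img, flow) == img).all()
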
if __name__ == '__main__':
    args = parse_args()
    main(args)
download_training_data.sh
ADDED
@@ -0,0 +1,53 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

CODE_ROOT=`pwd`
if [ ! -e datasets ]; then
    echo "Error: missing datasets/ folder"
    echo "First, create a folder that can host (at least) 15 GB of data."
    echo "Then, create a soft-link named 'data' that points to it."
    exit -1
fi

# download some web images from the revisitop1m dataset
WEB_ROOT=datasets/revisitop1m
mkdir -p $WEB_ROOT
cd $WEB_ROOT
if [ ! -e 0d3 ]; then
    for i in {1..5}; do
        echo "Installing the web images dataset ($i/5)..."
        if [ ! -f revisitop1m.$i.tar.gz ]; then
            wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.$i.tar.gz
        fi
        tar -xzvf revisitop1m.$i.tar.gz
        rm -f revisitop1m.$i.tar.gz
    done
fi
cd $CODE_ROOT

# download SfM120k pairs
SFM_ROOT=datasets/sfm120k
mkdir -p $SFM_ROOT
cd $SFM_ROOT
if [ ! -e "ims" ]; then
    echo "Downloading the SfM120k dataset..."
    fname=ims.tar.gz
    if [ ! -f $fname ]; then
        wget http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/ims/ims.tar.gz
    fi
    tar -xzvf $fname -C ims
    rm -f $fname
fi
if [ ! -e "corres" ]; then
    echo "Installing the SfM120k dataset..."
    fname=corres.tar.gz
    if [ ! -f $fname ]; then
        wget https://download.europe.naverlabs.com/corres.tar.gz
    fi
    tar -xzvf $fname
    rm -f $fname
fi
cd $CODE_ROOT

echo "Done!"
imgs/demo_warp.jpg
ADDED
imgs/overview.png
ADDED
imgs/teaser_paper.jpg
ADDED
imgs/test.png
ADDED
post_filter.py
ADDED
@@ -0,0 +1,235 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import pdb, sys, os
import argparse
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix, triu, csgraph

import core.functional as myF
from tools.common import image, image_with_trf
from tools.viz import dbgfig, show_correspondences


def arg_parser():
    parser = argparse.ArgumentParser("Post-filtering of Deep matching correspondences")

    parser.add_argument("--img1", required=True, help="path to first image")
    parser.add_argument("--img2", required=True, help="path to second image")
    parser.add_argument("--resize", default=0, type=int, help="prior image downsize (0 if recursive)")
    parser.add_argument("--corres", required=True, help="input path")
    parser.add_argument("--output", default="", help="filtered corres output")

    parser.add_argument("--locality", type=float, default=2, help="tolerance to deformation")
    parser.add_argument("--min-cc-size", type=int, default=50, help="min connex-component size")
    parser.add_argument("--densify", default='no', choices=['no','full','cc','convex'], help="output pixel-dense corres field")
    parser.add_argument("--dense-side", default='left', choices=['left','right'], help="img to densify")

    parser.add_argument("--verbose", "-v", type=int, default=0, help="verbosity level")
    parser.add_argument("--dbg", type=str, nargs='+', default=(), help="debug options")
    return parser


def main(args):
    import test_singlescale as pump
    corres = np.load(args.corres)['corres']
    imgs = tuple(map(image, pump.Main.load_images(args)))

    if dbgfig('raw', args.dbg):
        show_correspondences(*imgs, corres)

    corres = filter_corres(*imgs, corres,
        locality=args.locality, min_cc_size=args.min_cc_size,
        densify=args.densify, dense_side=args.dense_side,
        verbose=args.verbose, dbg=args.dbg)

    if dbgfig('viz', args.dbg):
        show_correspondences(*imgs, corres)

    return pump.save_output(args, corres)


def filter_corres(img0, img1, corres,
        locality = None,     # graph edge locality
        min_cc_size = None,  # min CC size
        densify = None,
        dense_side = None,
        verbose = 0, dbg=()):

    if None in (locality, min_cc_size, densify, dense_side):
        default_params = arg_parser()
        locality = locality or default_params.get_default('locality')
        min_cc_size = min_cc_size or default_params.get_default('min_cc_size')
        densify = densify or default_params.get_default('densify')
        dense_side = dense_side or default_params.get_default('dense_side')

    img0, trf0 = img0 if isinstance(img0, tuple) else (img0, np.eye(3))
    img1, trf1 = img1 if isinstance(img1, tuple) else (img1, np.eye(3))
    assert isinstance(img0, np.ndarray) and isinstance(img1, np.ndarray)

    corres = myF.affmul((np.linalg.inv(trf0), np.linalg.inv(trf1)), corres)
    n_corres = len(corres)
    if verbose: print(f'>> input: {len(corres)} correspondences')

    graph = compute_graph(corres, max_dis=locality*4)
    if verbose: print(f'>> {locality=}: {graph.nnz} nodes in graph')

    cc_sizes = measure_connected_components(graph)
    corres[:,4] += np.log2(cc_sizes)
    corres = corres[cc_sizes > min_cc_size]
    if verbose: print(f'>> {min_cc_size=}: remaining {len(corres)} correspondences')

    final = myF.affmul((trf0, trf1), corres)

    if densify != 'no':
        # densify correspondences
        if dense_side == 'right':  # temporary swap
            final = final[:, [2,3,0,1]]
            H = round(img1.shape[0] / trf1[1,1])
            W = round(img1.shape[1] / trf1[0,0])
        else:
            H = round(img0.shape[0] / trf0[1,1])
            W = round(img0.shape[1] / trf0[0,0])

        if densify == 'cc':
            assert False, 'todo'
        elif densify in (True, 'full', 'convex'):
            # recover true image0's shape
            final = densify_corres(final, (H, W), full=(densify != 'convex'))
        else:
            raise ValueError(f'Bad mode for {densify=}')

        if dense_side == 'right':  # undo temporary swap
            final = final[:, [2,3,0,1]]

    return final


def compute_graph(corres, max_dis=10, min_ang=90):
    """ 4D distances (corres can only be connected to same scale)
    using sparse matrices for efficiency

    step1: build horizontal and vertical binning, binsize = max_dis
           add in each bin all neighbor bins
    step2: for each corres, we can intersect 2 bins to get a short list of candidates
    step3: verify euclidean distance < maxdis (optional?)
    """
    def bin_positions(pos):
        # every corres goes into a single bin
        bin_indices = np.int32(pos.clip(min=0) // max_dis) + 1
        cols = np.arange(len(pos))

        # add the cell before and the cell after, to handle border effects
        res = csr_matrix((np.ones(len(bin_indices)*3, dtype=np.float32),
                          (np.r_[bin_indices-1, bin_indices, bin_indices+1], np.r_[cols, cols, cols])),
                         shape=(bin_indices.max()+2 if bin_indices.size else 1, len(pos)))

        return res, bin_indices

    # 1-hot matrices of shape = nbins x n_corres
    x1_bins = bin_positions(corres[:,0])
    y1_bins = bin_positions(corres[:,1])
    x2_bins = bin_positions(corres[:,2])
    y2_bins = bin_positions(corres[:,3])

    def row_indices(ngh):
        res = np.bincount(ngh.indptr[1:-1], minlength=ngh.indptr[-1])[:-1]
        return res.cumsum()

    def compute_dist(ngh, pts, scale=None):
        # pos from the second point
        x_pos = pts[ngh.indices, 0]
        y_pos = pts[ngh.indices, 1]

        # subtract pos from the 1st point
        rows = row_indices(ngh)
        x_pos -= pts[rows, 0]
        y_pos -= pts[rows, 1]
        dis = np.sqrt(np.square(x_pos) + np.square(y_pos))
        if scale is not None:
            # there is a scale for each of the 2 pts; we lean towards the worst one,
            # hence the arithmetic rather than the geometric mean
            dis *= (scale[rows] + scale[ngh.indices]) / 2

        return normed(np.c_[x_pos, y_pos]), dis

    def Rot(ngh, degrees):
        rows = row_indices(ngh)
        rad = degrees * np.pi / 180
        rad = (rad[rows] + rad[ngh.indices]) / 2  # average angle between 2 corres
        cos, sin = np.cos(rad), np.sin(rad)
        return np.float32(((cos, -sin), (sin, cos))).transpose(2, 0, 1)

    def match(xbins, ybins, pt1, pt2, way):
        xb, ixb = xbins
        yb, iyb = ybins

        # gets for each corres a list of potential matches
        ngh = xb[ixb].multiply(yb[iyb])  # shape = n_corres x n_corres
        ngh = triu(ngh, k=1).tocsr()     # remove mirrored matches
        # ngh = matches of matches, shape = n_corres x n_corres

        # verify locality and flow
        vec1, d1 = compute_dist(ngh, pt1)  # for each match, distance and orientation in img1
        # assert d1.max()**0.5 < 2*max_dis*1.415, 'cannot be larger than 2 cells in diagonals, or there is a bug'+bb()
        scale, rot = myF.decode_scale_rot(corres[:,5])
        vec2, d2 = compute_dist(ngh, pt2, scale=scale**(-way))
        ang = np.einsum('ik,ik->i', (vec1[:,None] @ Rot(ngh, way*rot))[:,0], vec2)

        valid = (d1 <= max_dis) & (d2 <= max_dis) & (ang >= np.cos(min_ang*np.pi/180))
        res = csr_matrix((valid, ngh.indices, ngh.indptr), shape=ngh.shape)
        res.eliminate_zeros()
        return res

    # find all neighbors within each xy bin
    ngh1 = match(x1_bins, y1_bins, corres[:,0:2], corres[:,2:4], way=+1)
    ngh2 = match(x2_bins, y2_bins, corres[:,2:4], corres[:,0:2], way=-1).T

    return ngh1 + ngh2  # union

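# --- Minimal sketch of the binning trick above (not in the original file) ----
# Two points are neighbour candidates iff they share an x-bin AND a y-bin;
# elementwise-multiplying two one-hot bin matrices yields exactly that.
# (The real bin_positions() also registers each point in the 2 adjacent bins.)
def _binning_demo():
    pts = np.float32([[0, 0], [3, 1], [40, 40]])
    max_dis = 10
    bins = np.int32(pts // max_dis) + 1  # (x_bin, y_bin) per point
    def one_hot(b):                      # nbins x npoints indicator matrix
        return csr_matrix((np.ones(len(b), dtype=np.float32), (b, np.arange(len(b)))))
    xb, yb = one_hot(bins[:, 0]), one_hot(bins[:, 1])
    cand = xb[bins[:, 0]].multiply(yb[bins[:, 1]])  # npoints x npoints candidates
    print(cand.toarray())  # points 0 and 1 share a bin; point 2 is isolated
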
def measure_connected_components(graph, dbg=()):
    # compute connected components
    nc, labels = csgraph.connected_components(graph, directed=False)

    # filter and remove all small components
    count = np.bincount(labels)

    return count[labels]

def normed(mat):
    return mat / np.linalg.norm(mat, axis=-1, keepdims=True).clip(min=1e-16)


def densify_corres(corres, shape, full=True):
    from scipy.interpolate import LinearNDInterpolator
    from scipy.spatial import cKDTree as KDTree

    assert len(corres) > 3, 'Not enough corres for densification'
    H, W = shape

    interp = LinearNDInterpolator(corres[:,0:2], corres[:,2:4])
    X, Y = np.mgrid[0:H, 0:W][::-1]  # H x W, H x W
    p1 = np.c_[X.ravel(), Y.ravel()]
    p2 = interp(X, Y)  # H x W x 2

    p2 = p2.reshape(-1, 2)
    invalid = np.isnan(p2).any(axis=1)

    if full:
        # interpolate pixels outside of the convex hull
        badp = p1[invalid]
        tree = KDTree(corres[:,0:2])
        _, nn = tree.query(badp, 3)  # find 3 closest neighbors
        corflow = corres[:,2:4] - corres[:,0:2]
        p2.reshape(-1, 2)[invalid] = corflow[nn].mean(axis=1) + p1[invalid]
    else:
        # remove nans, i.e. remove points outside of convex hull
        p1, p2 = p1[~invalid], p2[~invalid]

    # return correspondence field
    return np.c_[p1, p2]

|
235 |
+
main(arg_parser().parse_args())
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
matplotlib
numpy
scipy
torch==1.11.0
torchvision==0.12.0
run_ETH3D.py
ADDED
@@ -0,0 +1,118 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os, os.path as osp
from tqdm import tqdm
import numpy as np

SEQUENCES = ['lakeside', 'sand_box', 'storage_room', 'storage_room_2', 'tunnel',
             'delivery_area', 'electro', 'forest', 'playground', 'terrains']

RATES = [3, 5, 7, 9, 11, 13, 15]

def parse_args():
    import argparse
    parser = argparse.ArgumentParser('PUMP evaluation script for the ETH3D dataset')

    parser.add_argument('--root', default='datasets/eth3d')
    parser.add_argument('--output', default='results/eth3d')

    parser.add_argument('--just-print', action='store_true', help='just print commands')
    return parser.parse_args()


def main(args):
    run_pump(args) and run_eval(args)


def run_pump(args):
    done = True
    for img1, img2 in tqdm(list_eth3d_pairs()):
        output_path = osp.join(args.output, img1, img2+'.corres')
        if osp.isfile(output_path): continue

        done = False
        _exec(f'''python test_multiscale_recursive.py
            --img1 {osp.join(args.root, img1)}
            --img2 {osp.join(args.root, img2)}
            --max-scale 1.5
            --desc PUMP
            --post-filter "densify=True,dense_side='right'"
            --output {output_path}''')

    return done


def run_eval(args):
    for rate in RATES:
        mean_aepe_per_rate = 0

        for seq in SEQUENCES:
            pairs = np.load(osp.join(args.root, 'info_ETH3D_files', f'{seq}_every_5_rate_of_{rate}'), allow_pickle=True)

            mean_aepe_per_seq = 0
            for pair in pairs:
                img1, img2 = pair['source_image'], pair['target_image']
                Ys, Xs, Yt, Xt = [np.float32(pair[k]) for k in 'Ys Xs Yt Xt'.split()]

                corres_path = osp.join(args.output, img1, img2+'.corres')
                corres = np.load(corres_path, allow_pickle=True)['corres']

                # extract estimated and target flow
                W, H = np.int32(corres[-1, 2:4] + 1)
                flow = (corres[:,0:2] - corres[:,2:4]).reshape(H, W, 2)
                iYt, iXt = np.int32(np.round(Yt)), np.int32(np.round(Xt))
                if 'correct way':
                    gt_targets = np.c_[Xs - Xt, Ys - Yt]
                    est_targets = flow[iYt, iXt]
                elif 'GLU-Net way (somewhat inaccurate because of overlapping points in the mask)':
                    mask = np.zeros((H, W), dtype=bool)
                    mask[iYt, iXt] = True
                    gt_flow = np.full((H, W, 2), np.nan, dtype=np.float32)
                    gt_flow[iYt, iXt, 0] = Xs - Xt
                    gt_flow[iYt, iXt, 1] = Ys - Yt
                    gt_targets = gt_flow[mask]
                    est_targets = flow[mask]

                # compute end-point error
                aepe = np.linalg.norm(est_targets - gt_targets, axis=-1).mean()
                mean_aepe_per_seq += aepe

            mean_aepe_per_seq /= len(pairs)
            mean_aepe_per_rate += mean_aepe_per_seq
            print(f'mean AEPE for {rate=} {seq=}:', mean_aepe_per_seq)

        print(f'>> mean AEPE for {rate=}:', mean_aepe_per_rate / len(SEQUENCES))

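# --- Illustrative helper (appended sketch, not in the original file) ---------
# The AEPE above is just the mean Euclidean norm between estimated and
# ground-truth flow vectors; e.g. per-pixel errors (3,4) and (0,0) average to 2.5:
#
#   def aepe(est, gt):  # est, gt: (N, 2) flow arrays
#       return np.linalg.norm(est - gt, axis=-1).mean()
#   aepe(np.float32([[3, 4], [0, 0]]), np.zeros((2, 2)))  # -> 2.5
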
def list_eth3d_pairs():
    path = osp.join(args.root, 'info_ETH3D_files', 'list_pairs.txt')
    try:
        lines = open(path).read().splitlines()
    except OSError:
        lines = []
        for seq in SEQUENCES:
            for rate in RATES:
                pairs = np.load(osp.join(args.root, 'info_ETH3D_files', f'{seq}_every_5_rate_of_{rate}'), allow_pickle=True)
                for pair in pairs:
                    lines.append(pair['source_image'] + ' ' + pair['target_image'])
        open(path, 'w').write('\n'.join(lines))

    pairs = [line.split() for line in lines if line[0] != '#']
    return pairs


def _exec(cmd):
    # strip & remove \n
    cmd = ' '.join(cmd.split())
    if args.just_print:
        print(cmd)
    else:
        os.system(cmd)


if __name__ == '__main__':
    args = parse_args()
    main(args)
test_multiscale.py
ADDED
@@ -0,0 +1,262 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
from itertools import starmap
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import test_singlescale as tss
from core import functional as myF
from tools.common import todevice, cpu
from tools.viz import dbgfig, show_correspondences


def arg_parser():
    parser = tss.arg_parser()
    parser.set_defaults(levels=0, verbose=0)

    parser.add_argument('--min-scale', type=float, default=None, help='min scale ratio')
    parser.add_argument('--max-scale', type=float, default=4, help='max scale ratio')

    parser.add_argument('--min-rot', type=float, default=None, help='min rotation (in degrees) in [-180,180]')
    parser.add_argument('--max-rot', type=float, default=0, help='max rotation (in degrees) in [0,180]')
    parser.add_argument('--crop-rot', action='store_true', help='crop rotated image to prevent memory blow-up')
    parser.add_argument('--rot-step', type=int, default=45, help='rotation step (in degrees)')

    parser.add_argument('--no-swap', type=int, default=1, nargs='?', const=0, choices=[1,0,-1], help='if 0, img1 will have keypoints on a grid')
    parser.add_argument('--same-levels', action='store_true', help='use the same number of pyramid levels for all scales')

    parser.add_argument('--merge', choices='torch cpu cuda'.split(), default='cpu')
    return parser


class MultiScalePUMP (nn.Module):
    """ DeepMatching that loops over all possible {scale x rotation} combinations.
    """
    def __init__(self, matcher,
                       min_scale=1,
                       max_scale=1,
                       max_rot=0,
                       min_rot=0,
                       rot_step=45,
                       swap_mode=1,
                       same_levels=False,
                       crop_rot=False):
        super().__init__()
        min_scale = min_scale or 1/max_scale
        min_rot = min_rot or -max_rot
        assert 0.1 <= min_scale <= max_scale <= 10
        assert -180 <= min_rot <= max_rot <= 180
        self.matcher = matcher
        self.matcher.crop_rot = crop_rot

        self.min_sc = min_scale
        self.max_sc = max_scale
        self.min_rot = min_rot
        self.max_rot = max_rot
        self.rot_step = rot_step
        self.swap_mode = swap_mode
        self.merge_device = None
        self.same_levels = same_levels

    @torch.no_grad()
    def forward(self, img1, img2, dbg=()):
        img1, sca1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3, device=img1.device))
        img2, sca2 = img2 if isinstance(img2, tuple) else (img2, torch.eye(3, device=img2.device))

        # prepare correspondences accumulators
        if self.same_levels:  # limit number of levels
            self.matcher.levels = self._find_max_levels(img1, img2)
        elif self.matcher.levels == 0:
            max_psize = int(min(np.mean(img1.shape[-2:]), np.mean(img2.shape[-2:])))
            self.matcher.levels = int(np.log2(max_psize / self.matcher.pixel_desc.get_atomic_patch_size()))

        all_corres = (self._make_accu(img1), self._make_accu(img2))

        for scale, ang, code, swap, swapped, (scimg1, scimg2) in self._enum_scaled_pairs(img1, img2):
            print(f"processing {scale=:g} x {ang=} {['','(swapped)'][swapped]} ({code=})...")

            # compute correspondences with rotated+scaled image
            corres, rots = self.process_one_scale(swapped, *[scimg1, scimg2], dbg=dbg)
            if dbgfig('corres-ms', dbg): viz_correspondences(img1, img2, *corres, fig='last')

            # merge correspondences in the reference frame
            self.merge_corres(corres, rots, all_corres, code)

        # final intersection
        corres = self.reciprocal(*all_corres)
        return myF.affmul(todevice((sca1, sca2), corres.device), corres)  # rescaling to original image scale

    def process_one_scale(self, swapped, *imgs, dbg=()):
        return unswap(self.matcher(*imgs, ret='raw', dbg=dbg), swapped)

    def _find_max_levels(self, img1, img2):
        min_levels = self.matcher.levels or 999
        for _, _, code, _, _, (img1, img2) in self._enum_scaled_pairs(img1, img2):
            # first level when a parent doesn't have children: gap >= min(shape), with gap = 2**(level-2)
            img1_levels = ceil(np.log2(min(img1[0].shape[-2:])) - 1)
            # first level when img2's shape becomes smaller than self.min_shape, with shape = min(shape) / 2**level
            img2_levels = ceil(np.log2(min(img2[0].shape[-2:]) / self.matcher.min_shape))
            # print(f'predicted levels for {code=}:\timg1 --> {img1_levels},\timg2 --> {img2_levels} levels')
            min_levels = min(min_levels, img1_levels, img2_levels)
        return min_levels

    def merge_corres(self, corres, rots, all_corres, code):
        " rot : reference --> rotated "
        self.merge_one_side(corres[0], slice(0, 2), rots[0], all_corres[0], code)
        self.merge_one_side(corres[1], slice(2, 4), rots[1], all_corres[1], code)

    def merge_one_side(self, corres, sel, trf, all_corres, code):
        pos, scores = corres
        grid, accu = all_corres
        accu = accu.view(-1, 6)

        # compute 4-nn in transformed image for each grid point
        best4 = torch.cdist(pos[:, sel].float(), grid).topk(4, dim=0, largest=False)
        # best4.shape = (4, len(grid))

        # update if score is better AND distance less than 2x best dist
        scale = float(torch.sqrt(torch.det(trf)))  # == scale (with scale >= 1)
        dist_max = 8*scale - 1e-7  # 2x the distance between contiguous patches

        close_enough = (best4.values <= 2*best4.values[0:1]) & (best4.values < dist_max)
        neg_inf = torch.tensor(-np.inf, device=scores.device)
        best_score = torch.where(close_enough, scores.ravel()[best4.indices], neg_inf).max(dim=0)
        is_better = best_score.values > accu[:, 4].ravel()

        accu[is_better, 0:4] = pos[best4.indices[best_score.indices, torch.arange(len(grid))][is_better]]
        accu[is_better, 4] = best_score.values[is_better]
        accu[is_better, 5] = code

    def reciprocal(self, corres1, corres2):
        grid1, corres1 = cpu(corres1)
        grid2, corres2 = cpu(corres2)

        (H1, W1), (H2, W2) = grid1[-1]+1, grid2[-1]+1
        pos1 = corres1[:, :, 0:4].view(-1, 4)
        pos2 = corres2[:, :, 0:4].view(-1, 4)

        to_int = torch.tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
        inter1 = myF.intersection(pos1 @ to_int, pos2 @ to_int)
        return corres1.view(-1, 6)[inter1]

    def _enum_scales(self):
        for i in range(-100, 101):
            scale = 2**(i/2)
            # if i != -2: continue
            if self.min_sc <= scale <= self.max_sc:
                yield i, scale

    def _enum_rotations(self):
        for i in range(-180//self.rot_step, 180//self.rot_step):
            rot = i * self.rot_step
            if self.min_rot <= rot <= self.max_rot:
                yield i, -rot

    def _enum_scaled_pairs(self, img1, img2):
        for s, scale in self._enum_scales():
            (i1, sca1), (i2, sca2) = starmap(downsample_img, [(img1, min(scale, 1)), (img2, min(1/scale, 1))])
            # set bigger image as the first one
            size1 = min(i1.shape[-2:])
            size2 = min(i2.shape[-2:])
            swapped = size1*self.swap_mode < size2*self.swap_mode
            swap = (1 - 2*swapped)  # swapped ==> swap = -1
            if swapped:
                (i1, sca1), (i2, sca2) = (i2, sca2), (i1, sca1)

            for r, ang in self._enum_rotations():
                code = myF.encode_scale_rot(scale, ang)
                trf1 = (sca1, swap*ang) if ang != 0 else sca1
                yield scale, ang, code, swap, swapped, ((i1, trf1), (i2, sca2))

    def _make_accu(self, img):
        C, H, W = img.shape
        step = self.matcher.pixel_desc.get_atomic_patch_size() // 2
        h = step//2 - 1
        accu = img.new_zeros(((H+h)//step, (W+h)//step, 6), dtype=torch.float32, device=self.merge_device or img.device)
        grid = step * myF.mgrid(accu[:,:,0], device=img.device) + (step//2)
        return grid, accu

def downsample_img(img, scale=0):
|
186 |
+
assert scale <= 1
|
187 |
+
img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
|
188 |
+
if scale == 1: return img, trf
|
189 |
+
|
190 |
+
assert img.dtype == torch.uint8
|
191 |
+
trf = trf.clone() # dont modify inplace
|
192 |
+
trf[:2,:2] /= scale
|
193 |
+
while scale <= 0.5:
|
194 |
+
img = F.avg_pool2d(img[None].float(), 2, stride=2, count_include_pad=False)[0]
|
195 |
+
scale *= 2
|
196 |
+
if scale != 1:
|
197 |
+
img = F.interpolate(img[None].float(), scale_factor=scale, mode='bicubic', align_corners=False, recompute_scale_factor=False).clamp(min=0, max=255)[0]
|
198 |
+
return img.byte(), trf # scaled --> pxl
|
199 |
+
|
200 |
+
|
201 |
+
def ceil(i):
|
202 |
+
return int(np.ceil(i))
|
203 |
+
|
204 |
+
def unswap( corres, swapped ):
|
205 |
+
swap = -1 if swapped else 1
|
206 |
+
corres, rots = corres
|
207 |
+
corres = corres[::swap]
|
208 |
+
rots = rots[::swap]
|
209 |
+
if swapped:
|
210 |
+
for pos, _ in corres:
|
211 |
+
pos[:,0:4] = pos[:,[2,3,0,1]].clone()
|
212 |
+
return corres, rots
|
213 |
+
|
214 |
+
|
215 |
+
def demultiplex_img_trf(self, img, force=False):
|
216 |
+
""" img is:
|
217 |
+
- an image
|
218 |
+
- a tuple (image, trf)
|
219 |
+
- a tuple (image, (cur_trf, trf_todo))
|
220 |
+
In any case, trf: cur_pix --> old_pix
|
221 |
+
"""
|
222 |
+
img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
|
223 |
+
|
224 |
+
if isinstance(trf, tuple):
|
225 |
+
trf, todo = trf
|
226 |
+
if isinstance(todo, (int,float)): # pure rotation
|
227 |
+
img, trf = myF.rotate_img((img,trf), angle=todo, crop=self.crop_rot)
|
228 |
+
else:
|
229 |
+
img = myF.apply_trf_to_img(todo, img)
|
230 |
+
trf = trf @ todo
|
231 |
+
return img, trf
|
232 |
+
|
233 |
+
|
234 |
+
class Main (tss.Main):
|
235 |
+
@staticmethod
|
236 |
+
def get_options( args ):
|
237 |
+
return dict(max_scale=args.max_scale, min_scale=args.min_scale,
|
238 |
+
max_rot=args.max_rot, min_rot=args.min_rot, rot_step=args.rot_step,
|
239 |
+
swap_mode=args.no_swap, same_levels=args.same_levels, crop_rot=args.crop_rot)
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
def tune_matcher( args, matcher, device ):
|
243 |
+
if device == 'cpu':
|
244 |
+
args.merge = 'cpu'
|
245 |
+
|
246 |
+
if args.merge == 'cpu': type(matcher).merge_corres = myF.merge_corres; matcher.merge_device = 'cpu'
|
247 |
+
elif args.merge == 'cuda': type(matcher).merge_corres = myF.merge_corres
|
248 |
+
|
249 |
+
return matcher.to(device)
|
250 |
+
|
251 |
+
@staticmethod
|
252 |
+
def build_matcher( args, device):
|
253 |
+
# get a normal matcher
|
254 |
+
matcher = tss.Main.build_matcher(args, device)
|
255 |
+
type(matcher).demultiplex_img_trf = demultiplex_img_trf # update transformer
|
256 |
+
|
257 |
+
options = Main.get_options(args)
|
258 |
+
return Main.tune_matcher(args, MultiScalePUMP(matcher, **options), device)
|
259 |
+
|
260 |
+
|
261 |
+
if __name__ == '__main__':
|
262 |
+
Main().run_from_args(arg_parser().parse_args())
|
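
For reference, a minimal self-contained sketch of the score-gated merge performed by merge_one_side above, on toy tensors (the shapes and the 2x-distance gate follow the code; the variable names and values are illustrative):

import torch

grid = torch.tensor([[8., 8.], [24., 8.]])            # accumulator grid points
pos  = torch.tensor([[9., 9.], [10., 6.], [25., 9.]]) # candidate endpoints
scores = torch.tensor([0.5, 0.9, 0.2])

# 2 nearest candidates for each grid point (dim=0 indexes candidates)
best = torch.cdist(pos, grid).topk(2, dim=0, largest=False)

# accept a candidate only within 2x the best distance and below a hard cap
close_enough = (best.values <= 2 * best.values[0:1]) & (best.values < 8.0)
neg_inf = torch.tensor(float('-inf'))
winner = torch.where(close_enough, scores[best.indices], neg_inf).max(dim=0)
print(winner.values)  # best accepted score per grid point (-inf if none)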
test_multiscale_recursive.py
ADDED
@@ -0,0 +1,24 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import test_singlescale as ss
import test_singlescale_recursive as ssr
import test_multiscale as ms

def arg_parser():
    parser = ssr.arg_parser(ms.arg_parser())
    return parser

class Main (ms.Main):
    @staticmethod
    def build_matcher(args, device):
        # get a single-scale recursive matcher
        matcher = ssr.Main.build_matcher(args, device)
        type(matcher).demultiplex_img_trf = ms.demultiplex_img_trf  # update transformer

        options = Main.get_options(args)
        return Main.tune_matcher(args, ms.MultiScalePUMP(matcher, **options), device).to(device)

if __name__ == '__main__':
    Main().run_from_args(arg_parser().parse_args())
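
For illustration, a hypothetical programmatic use of this entry point, equivalent to running the script from the command line (the image paths and output name below are placeholders; a checkpoint under checkpoints/ is assumed to be available):

import test_multiscale_recursive as msr

args = msr.arg_parser().parse_args([
    '--img1', 'path/to/img1.jpg',   # placeholder paths
    '--img2', 'path/to/img2.jpg',
    '--output', 'corres.npz'])
msr.Main().run_from_args(args)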
test_singlescale.py
ADDED
@@ -0,0 +1,284 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from core import functional as myF
from core.pixel_desc import PixelDesc
from tools.common import mkdir_for, todevice, cudnn_benchmark, nparray, image, image_with_trf
from tools.viz import dbgfig, show_correspondences


def arg_parser():
    import argparse
    parser = argparse.ArgumentParser('SingleScalePUMP on GPU with PyTorch')

    parser.add_argument('--img1', required=True, help='path to img1')
    parser.add_argument('--img2', required=True, help='path to img2')
    parser.add_argument('--resize', type=int, default=512, nargs='+', help='prior downsize of img1 and img2')

    parser.add_argument('--output', default=None, help='output path for correspondences')

    parser.add_argument('--levels', type=int, default=99, help='number of pyramid levels')
    parser.add_argument('--min-shape', type=int, default=5, help='minimum size of corr maps')
    parser.add_argument('--nlpow', type=float, default=1.5, help='non-linear activation power in [1,2]')
    parser.add_argument('--border', type=float, default=0.9, help='border invariance level in [0,1]')
    parser.add_argument('--dtype', default='float16', choices='float16 float32 float64'.split())

    parser.add_argument('--desc', default='PUMP-stytrf', help='checkpoint name')
    parser.add_argument('--first-level', choices='torch'.split(), default='torch')
    parser.add_argument('--activation', choices='torch'.split(), default='torch')
    parser.add_argument('--forward', choices='torch cuda cuda-lowmem'.split(), default='cuda-lowmem')
    parser.add_argument('--backward', choices='python torch cuda'.split(), default='cuda')
    parser.add_argument('--reciprocal', choices='cpu cuda'.split(), default='cpu')

    parser.add_argument('--post-filter', default=None, const=True, nargs='?', help='post-filtering (see post_filter.py)')

    parser.add_argument('--verbose', type=int, default=0, help='verbosity')
    parser.add_argument('--device', default='cuda', help='gpu device')
    parser.add_argument('--dbg', nargs='*', default=(), help='debug options')

    return parser


class SingleScalePUMP (nn.Module):
    def __init__(self, levels = 9, nlpow = 1.4, cutoff = 1,
                       border_inv=0.9, min_shape=5, renorm=(),
                       pixel_desc = None, dtype = torch.float32,
                       verbose = True ):
        super().__init__()
        self.levels = levels
        self.min_shape = min_shape
        self.nlpow = nlpow
        self.border_inv = border_inv
        assert pixel_desc, 'Requires a pixel descriptor'
        self.pixel_desc = pixel_desc.configure(self)
        self.dtype = dtype
        self.verbose = verbose

    @torch.no_grad()
    def forward(self, img1, img2, ret='corres', dbg=()):
        with cudnn_benchmark(False):
            # compute descriptors
            (img1, img2), pixel_descs, trfs = self.extract_descs(img1, img2, dtype=self.dtype)

            # backward and forward passes
            pixel_corr = self.first_level(*pixel_descs, dbg=dbg)
            pixel_corr = self.backward_pass(self.forward_pass(pixel_corr, dbg=dbg), dbg=dbg)

            # recover correspondences
            corres = myF.best_correspondences( pixel_corr )

            if dbgfig('corres', dbg): viz_correspondences(img1[0], img2[0], *corres, fig='last')
            corres = [(myF.affmul(trfs,pos),score) for pos, score in corres]  # rectify scaling etc.
            if ret == 'raw': return corres, trfs
            return self.reciprocal(*corres)

    def extract_descs(self, img1, img2, dtype=None):
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)
        desc1, trf1 = self.pixel_desc(img1)
        desc2, trf2 = self.pixel_desc(img2)
        return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1@trf1, sca2@trf2)

    def demultiplex_img_trf(self, img, **kw):
        return img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))

    def forward_pass(self, pixel_corr, dbg=()):
        weights = None
        if isinstance(pixel_corr, tuple):
            pixel_corr, weights = pixel_corr

        # first level with activation
        if self.verbose: print(f' Pyramid level {0} shape={tuple(pixel_corr.shape)}')
        pyramid = [ self.activation(0,pixel_corr) ]
        if dbgfig(f'corr0', dbg): viz_correlation_maps(*from_stack('img1','img2'), pyramid[0], fig='last')

        for level in range(1, self.levels+1):
            upper, weights = self.forward_level(level, pyramid[-1], weights)
            if weights.sum() == 0: break  # img1 has become too small

            # activation
            pyramid.append( self.activation(level,upper) )

            if self.verbose: print(f' Pyramid level {level} shape={tuple(upper.shape)}')
            if dbgfig(f'corr{level}', dbg): viz_correlation_maps(*from_stack('img1','img2'), upper, level=level, fig='last')
            if min(upper.shape[-2:]) <= self.min_shape: break  # img2 has become too small

        return pyramid

    def forward_level(self, level, corr, weights):
        # max-pooling
        pooled = F.max_pool2d(corr, 3, padding=1, stride=2)

        # sparse conv
        return myF.sparse_conv(level, pooled, weights, norm=self.border_inv)

    def backward_pass(self, pyramid, dbg=()):
        # same as the forward pass, but in reverse order
        for level in range(len(pyramid)-1, 0, -1):
            lower = self.backward_level(level, pyramid)
            # assert not torch.isnan(lower).any(), bb()
            if self.verbose: print(f' Pyramid level {level-1} shape={tuple(lower.shape)}')
            del pyramid[-1]  # free memory
            if dbgfig(f'corr{level}-bw', dbg): viz_correlation_maps(img1, img2, lower, fig='last')
        return pyramid[0]

    def backward_level(self, level, pyramid):
        # reverse sparse conv
        pooled = myF.sparse_conv(level, pyramid[level], reverse=True)

        # reverse max-pool and add to lower level
        return myF.max_unpool(pooled, pyramid[level-1])

    def activation(self, level, corr):
        assert 1 <= self.nlpow <= 3
        corr.clamp_(min=0).pow_(self.nlpow)
        return corr

    def first_level(self, desc1, desc2, dbg=()):
        assert desc1.ndim == desc2.ndim == 4
        assert len(desc1) == len(desc2) == 1, "not implemented"
        H1, W1 = desc1.shape[-2:]
        H2, W2 = desc2.shape[-2:]

        patches = F.unfold(desc1, 4, stride=4)  # B, C*4*4, H1*W1//16
        B, C, N = patches.shape  # C is now desc-dim * 16
        # rearrange(patches, 'B (C Kh Kw) N -> B N C Kh Kw', Kh=4, Kw=4)
        patches = patches.permute(0, 2, 1).view(B, N, C//16, 4, 4)  # fixed: second dim is N

        corr, norms = myF.normalized_corr(patches[0], desc2[0], ret_norms=True)
        if dbgfig('ncc',dbg):  # debug path: requires matplotlib.pyplot as pl
            for j in range(0,len(corr),9):
                for i in range(9):
                    pl.subplot(3,3,i+1).cla()
                    i += j
                    pl.imshow(corr[i], vmin=0.9, vmax=1)
                    pl.plot(2+(i%16)*4, 2+(i//16)*4,'xr', ms=10)
                bb()
        return corr.view(H1//4, W1//4, H2+1, W2+1), (norms.view(H1//4, W1//4)>0).float()

    def reciprocal(self, corres1, corres2 ):
        corres1, corres2 = todevice(corres1, 'cpu'), todevice(corres2, 'cpu')
        return myF.reciprocal(self, corres1, corres2)


class Main:
    def __init__(self):
        self.post_filtering = False

    def run_from_args(self, args):
        device = args.device
        self.matcher = self.build_matcher(args, device)
        if args.post_filter:
            self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')

        corres = self(*self.load_images(args, device), dbg=set(args.dbg))

        if args.output:
            self.save_output( args.output, corres )

    def run_from_args_with_images(self, img1, img2, args):
        device = args.device
        self.matcher = self.build_matcher(args, device)
        if args.post_filter:
            self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')

        if isinstance(args.resize, int):  # user can provide 2 separate sizes, one per image
            args.resize = (args.resize, args.resize)

        if len(args.resize) == 1:
            args.resize = 2 * args.resize

        images = []
        for imgx, size in zip([img1, img2], args.resize):
            img = torch.from_numpy(np.array(imgx.convert('RGB'))).permute(2,0,1).to(device)
            img = myF.imresize(img, size)
            images.append( img )

        corres = self(*images, dbg=set(args.dbg))

        if args.output:
            self.save_output( args.output, corres )

        return corres


    @staticmethod
    def get_options( args ):
        # configure the pipeline
        pixel_desc = PixelDesc(path=f'checkpoints/{args.desc}.pt')
        return dict(levels=args.levels, min_shape=args.min_shape, border_inv=args.border, nlpow=args.nlpow,
                    pixel_desc=pixel_desc, dtype=eval(f'torch.{args.dtype}'), verbose=args.verbose)

    @staticmethod
    def tune_matcher( args, matcher, device ):
        if device == 'cpu':
            matcher.dtype = torch.float32
            args.forward = 'torch'
            args.backward = 'torch'
            args.reciprocal = 'cpu'

        if args.forward == 'cuda':        type(matcher).forward_level = myF.forward_cuda
        if args.forward == 'cuda-lowmem': type(matcher).forward_level = myF.forward_cuda_lowmem
        if args.backward == 'python':     type(matcher).backward_pass = legacy.backward_python  # note: 'legacy' is not imported in this commit
        if args.backward == 'cuda':       type(matcher).backward_level = myF.backward_cuda
        if args.reciprocal == 'cuda':     type(matcher).reciprocal = myF.reciprocal

        return matcher.to(device)

    @staticmethod
    def build_matcher(args, device):
        options = Main.get_options(args)
        matcher = SingleScalePUMP(**options)
        return Main.tune_matcher(args, matcher, device)

    def __call__(self, *imgs, dbg=()):
        corres = self.matcher( *imgs, dbg=dbg).cpu().numpy()
        if self.post_filtering is not False:
            corres = self.post_filter( imgs, corres )

        if 'print' in dbg: print(corres)
        if dbgfig('viz',dbg): show_correspondences(*imgs, corres)
        return corres

    @staticmethod
    def load_images( args, device='cpu' ):
        def read_image(impath):
            try:
                from torchvision.io.image import read_image, ImageReadMode
                return read_image(impath, mode=ImageReadMode.RGB)
            except RuntimeError:
                from PIL import Image
                return torch.from_numpy(np.array(Image.open(impath).convert('RGB'))).permute(2,0,1)

        if isinstance(args.resize, int):  # user can provide 2 separate sizes, one per image
            args.resize = (args.resize, args.resize)

        if len(args.resize) == 1:
            args.resize = 2 * args.resize

        images = []
        for impath, size in zip([args.img1, args.img2], args.resize):
            img = read_image(impath).to(device)
            img = myF.imresize(img, size)
            images.append( img )
        return images

    def post_filter(self, imgs, corres ):
        from post_filter import filter_corres
        return filter_corres(*map(image_with_trf,imgs), corres, **self.post_filtering)

    def save_output(self, output_path, corres ):
        mkdir_for( output_path )
        np.savez(open(output_path,'wb'), corres=corres)


if __name__ == '__main__':
    Main().run_from_args(arg_parser().parse_args())
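
As a sanity check on the shapes used in first_level above, a small standalone sketch of the 4x4 patch extraction (toy dimensions; this mirrors only the unfold/permute/view sequence, not the full PUMP correlation):

import torch
import torch.nn.functional as F

desc1 = torch.randn(1, 32, 64, 80)            # B, C, H1, W1
patches = F.unfold(desc1, 4, stride=4)        # B, C*4*4, (H1//4)*(W1//4)
B, C16, N = patches.shape
patches = patches.permute(0, 2, 1).view(B, N, C16 // 16, 4, 4)
print(patches.shape)                          # torch.Size([1, 320, 32, 4, 4])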
test_singlescale_recursive.py
ADDED
@@ -0,0 +1,156 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
from tqdm import tqdm
import numpy as np
import torch

import test_singlescale as tss
import core.functional as myF
from tools.viz import dbgfig, show_correspondences


def arg_parser(parser = None):
    parser = parser or tss.arg_parser()

    parser.add_argument('--rec-overlap', type=float, default=0.5, help='overlap between tiles in [0,0.5]')
    parser.add_argument('--rec-score-thr', type=float, default=1, help='corres score threshold to guide fine levels')
    parser.add_argument('--rec-fast-thr', type=float, default=0.1, help='prune a block if less than this fraction of corres fall in it')

    return parser


class RecursivePUMP (tss.SingleScalePUMP):
    """ Recursive PUMP:
        1) find initial correspondences at a coarse scale,
        2) refine them at a selection of finer scales
    """
    def __init__(self, coarse_size=512, fine_size=512, rec_overlap=0.5, rec_score_thr=1.0,
                       rec_fast_thr = 0.1, **other_options ):
        super().__init__(**other_options)
        assert 10 < coarse_size < 1024
        assert 10 < fine_size < 1024
        assert 0 <= rec_overlap < 1
        assert 0 < rec_fast_thr < 1
        self.coarse_size = coarse_size
        self.fine_size = fine_size
        self.overlap = rec_overlap
        self.score_thr = rec_score_thr
        self.fast_thr = rec_fast_thr

    @torch.no_grad()
    def forward(self, img1, img2, ret='corres', dbg=()):
        img1, sca1 = self.demultiplex_img_trf(img1, force=True)
        img2, sca2 = self.demultiplex_img_trf(img2, force=True)
        input_trfs = (sca1, sca2)

        # coarse first level with low-res images
        corres = self.coarse_correspondences(img1, img2)

        # fine level: iterate on HQ blocks
        accu1, accu2 = (self._make_accu(img1), self._make_accu(img2))
        for block1, block2 in tqdm(list(self._enumerate_blocks(img1, img2, corres))):
            accus, trfs = tss.SingleScalePUMP.forward(self, block1, block2, ret='raw', dbg=dbg)
            self._update_accu( accu1, accus[0], trfs[0][:2,2] )
            self._update_accu( accu2, accus[1], trfs[1][:2,2] )

        demul = lambda accu: (accu[:,:,:4].reshape(-1,4).clone(), accu[:,:,4].clone())
        corres = demul(accu1), demul(accu2)
        if dbgfig('corres', dbg): viz_correspondences(img1, img2, *corres, fig='last')
        corres = [(myF.affmul(input_trfs,pos),score) for pos, score in corres]  # rectify scaling etc.
        if ret == 'raw': return corres, input_trfs
        return self.reciprocal(*corres)

    def coarse_correspondences(self, img1, img2, **kw):
        # joint image resize, because the relative size matters (multiscale)
        shape1, shape2 = img1.shape[-2:], img2.shape[-2:]
        if max(shape1 + shape2) > self.coarse_size:
            f1 = self.coarse_size / max(shape1)
            f2 = self.coarse_size / max(shape2)
            f = min(f1, f2)
            img1 = myF.imresize( img1, int(0.5+f*max(shape1)) )
            img2 = myF.imresize( img2, int(0.5+f*max(shape2)) )
        else:
            f = 1

        init_corres = tss.SingleScalePUMP.forward(self, img1, img2, **kw)
        # show_correspondences(img1, img2, init_corres, fig='last')
        corres = init_corres[init_corres[:,4] > self.score_thr]
        print(f" keeping {len(corres)}/{len(init_corres)} corres with score > {self.score_thr} ...")
        return corres

    def _update_accu(self, accu, update, offset ):
        pos, scores = update
        H, W = scores.shape
        offx, offy = map(lambda i: int(i/4), offset)
        accu = accu[offy:offy+H, offx:offx+W]
        better = accu[:,:,4] < scores
        accu[:,:,4][better] = scores[better].float()
        accu[:,:,0:4][better] = pos.reshape(H,W,4)[better]

    def _enumerate_blocks(self, img1, img2, corres):
        H1, W1, H2, W2 = img1.shape[1:] + img2.shape[1:]
        size, step = self.fine_size, int(self.overlap * self.fine_size)
        def regular_steps(size):
            if size <= self.fine_size: return [0]
            nb = int(np.ceil(size / step)) - 1  # guaranteed >= 1
            return (np.linspace(0, size-self.fine_size, nb) / 4 + 0.5).astype(int) * 4
        def translation(x,y):
            res = torch.eye(3, device=img1.device)
            res[0,2] = x
            res[1,2] = y
            return res
        def block2(x2,y2):
            return img2[:,y2:y2+size,x2:x2+size], translation(x2,y2)
        cx1, cy1 = corres[:,0:2].T

        for y1 in regular_steps(H1):
            for x1 in regular_steps(W1):
                block1 = (img1[:,y1:y1+size,x1:x1+size], translation(x1,y1))
                c2 = corres[(y1<=cy1) & (cy1<y1+size) & (x1<=cx1) & (cx1<x1+size)]
                nb_init = len(c2)
                while len(c2):
                    cx2, cy2 = c2[:,2:4].T
                    x2, y2 = (int(max(0,min(W2-size,cx2.median()-size//2)) / 4 + 0.5) * 4,
                              int(max(0,min(H2-size,cy2.median()-size//2)) / 4 + 0.5) * 4)
                    inside = (y2<=cy2) & (cy2<y2+size) & (x2<=cx2) & (cx2<x2+size)
                    if not inside.any():
                        x2, y2 = c2[np.random.choice(len(c2)),2:4]
                        x2 = int(max(0,min(W2-size,x2-size//2)) / 4 + 0.5) * 4
                        y2 = int(max(0,min(H2-size,y2-size//2)) / 4 + 0.5) * 4
                        inside = (y2<=cy2) & (cy2<y2+size) & (x2<=cx2) & (cx2<x2+size)

                    if inside.sum()/nb_init >= self.fast_thr:
                        yield block1, block2(x2,y2)

                    c2 = c2[~inside]  # remove processed corres

    def _make_accu(self, img):
        C, H, W = img.shape
        return img.new_zeros(((H+3)//4, (W+3)//4, 5), dtype=torch.float32)


class Main (tss.Main):
    @staticmethod
    def build_matcher(args, device):
        # set coarse and fine size based on the now obsolete --resize argument
        if isinstance(args.resize, int): args.resize = [args.resize]
        if len(args.resize) == 1: args.resize *= 2
        args.rec_coarse_size, args.rec_fine_size = args.resize
        args.resize = 0  # disable it so that image loading does not downsize images

        options = Main.get_options( args )

        matcher = RecursivePUMP( coarse_size=args.rec_coarse_size, fine_size=args.rec_fine_size,
                    rec_overlap=args.rec_overlap, rec_score_thr=args.rec_score_thr, rec_fast_thr=args.rec_fast_thr,
                    **options)

        return tss.Main.tune_matcher(args, matcher, device).to(device)  # fixed: pass (args, matcher, device) as expected


if __name__ == '__main__':
    Main().run_from_args(arg_parser().parse_args())
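
A standalone sketch of the tile-origin computation used by _enumerate_blocks (same arithmetic, with the sizes made explicit; the values are illustrative):

import numpy as np

fine_size, overlap = 512, 0.5
step = int(overlap * fine_size)

def regular_steps(size):
    # origins of overlapping tiles, snapped to multiples of 4
    if size <= fine_size: return [0]
    nb = int(np.ceil(size / step)) - 1
    return (np.linspace(0, size - fine_size, nb) / 4 + 0.5).astype(int) * 4

print(regular_steps(1200))   # -> [  0 228 460 688]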
tools/common.py
ADDED
@@ -0,0 +1,95 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import os
import torch
import numpy as np


def mkdir_for(file_path):
    dirname = os.path.split(file_path)[0]
    if dirname: os.makedirs(dirname, exist_ok=True)
    return file_path


def model_size(model):
    ''' Computes the number of parameters of the model
    '''
    size = 0
    for weights in model.state_dict().values():
        size += np.prod(weights.shape)
    return size


class cudnn_benchmark:
    " context manager to temporarily set the cudnn benchmark mode "
    def __init__(self, activate ):
        self.activate = activate
    def __enter__(self):
        self.old_bm = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = self.activate
    def __exit__(self, *args):
        torch.backends.cudnn.benchmark = self.old_bm


def todevice(x, device, non_blocking=False):
    """ Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
        x: array, tensor, or container of such.
        device: pytorch device or 'numpy'
    """
    if isinstance(x, dict):
        return {k:todevice(v, device) for k,v in x.items()}

    if isinstance(x, (tuple,list)):
        return type(x)(todevice(e, device) for e in x)

    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        x = x.to(device, non_blocking=non_blocking)
    return x

def nparray( x ): return todevice(x, 'numpy')
def cpu( x ): return todevice(x, 'cpu')
def cuda( x ): return todevice(x, 'cuda')


def image( img, with_trf=False ):
    " convert a torch.Tensor to a numpy image (H, W, 3) "
    def convert_image(img):
        if isinstance(img, torch.Tensor):
            if img.dtype is not torch.uint8:
                img = img * 255
            if img.min() < -10:  # undo ImageNet normalization
                img = img.clone()
                for i, (mean, std) in enumerate(zip([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])):
                    img[i] *= std
                    img[i] += 255*mean
            img = img.byte()
            if img.shape[0] <= 3:
                img = img.permute(1,2,0)
        return img

    if isinstance(img, tuple):
        if with_trf:
            return nparray(convert_image(img[0])), nparray(img[1])
        else:
            img = img[0]
    return nparray(convert_image(img))


def image_with_trf( img ):
    return image(img, with_trf=True)

class ToTensor:
    " numpy images to float tensors "
    def __call__(self, x):
        assert x.ndim == 4 and x.shape[3] == 3
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        assert x.dtype == torch.uint8
        return x.permute(0, 3, 1, 2).float() / 255
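
A short usage example of the todevice helper above, showing how it recurses into nested containers (a sketch; runs on CPU):

import numpy as np
import torch
from tools.common import todevice, nparray

batch = {'img': np.zeros((2, 3), dtype=np.float32),
         'meta': [np.arange(4), {'idx': np.array([1, 2])}]}

as_torch = todevice(batch, 'cpu')   # numpy arrays become torch tensors
back = nparray(as_torch)            # and back to numpy, same structure
print(type(as_torch['img']), type(back['meta'][0]))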
tools/trainer.py
ADDED
@@ -0,0 +1,125 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import pdb; bb = pdb.set_trace
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
from torch.nn import DataParallel

from .common import todevice


class Trainer (nn.Module):
    """ Helper class to train a deep network.
        Overload this class' `forward_backward` for your actual needs.

        Usage:
            train = Trainer(net, loss, optimizer)
            for epoch in range(n_epochs):
                train()
    """
    def __init__(self, net, loss, optimizer, epoch=0):
        super().__init__()
        self.net = net
        self.loss = loss
        self.optimizer = optimizer
        self.epoch = epoch

    @property
    def device(self):
        return next(self.net.parameters()).device

    @property
    def model(self):
        return self.net.module if isinstance(self.net, DataParallel) else self.net

    def distribute(self):
        self.net = DataParallel(self.net)  # DataDistributed not implemented yet

    def __call__(self, data_loader):
        print(f'>> Training (epoch {self.epoch} --> {self.epoch+1})')
        self.net.train()

        stats = defaultdict(list)

        for batch in tqdm(data_loader):
            batch = todevice(batch, self.device)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            details = self.forward_backward(batch)
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append( val )

        self.epoch += 1

        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            N = 1 + len(vals)//10
            print(f"  - {loss_name:10}: {avg(vals[:N]):.3f} --> {avg(vals[-N:]):.3f} (avg: {avg(vals):.3f})")

    def forward_backward(self, inputs):
        raise NotImplementedError()

    def save(self, path):
        print(f"\n>> Saving model to {path}")

        data = {'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'loss': self.loss.state_dict(),
                'epoch': self.epoch}

        torch.save(data, open(path,'wb'))

    def load(self, path, resume=True):
        print(f">> Loading weights from {path} ...")
        checkpoint = torch.load(path, map_location='cpu')
        assert isinstance(checkpoint, dict)

        self.net.load_state_dict(checkpoint['model'])
        if resume:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.loss.load_state_dict(checkpoint['loss'])  # fixed: was loading the optimizer state into the loss
            self.epoch = checkpoint['epoch']
            print(f" Resuming training at Epoch {self.epoch}!")


def get_loss( loss ):
    """ returns a tuple (loss, dictionary of loss details)
    """
    assert isinstance(loss, dict)
    grads = None

    k,l = next(iter(loss.items()))  # first item is assumed to be the main loss
    if isinstance(l, tuple):
        l, grads = l
        loss[k] = l

    return (l, grads), {k:float(v) for k,v in loss.items()}


def backward( loss ):
    if isinstance(loss, tuple):
        loss, grads = loss
    else:
        loss, grads = (loss, None)

    assert loss == loss, 'loss is NaN'

    if grads is None:
        loss.backward()
    else:
        # dictionary of separate subgraphs
        for var,grad in grads:
            var.backward(grad)
    return float(loss)


def avg( lis ):
    return sum(lis) / len(lis)
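
To make the overloading contract concrete, a minimal sketch of a Trainer subclass on a toy model (every name below is illustrative and not part of the repository; the loss dict passed to get_loss puts the main loss first, as assumed above):

import torch
import torch.nn as nn
from tools.trainer import Trainer, get_loss, backward

class ToyTrainer (Trainer):
    def forward_backward(self, batch):
        x, y = batch
        pred = self.net(x)
        main = ((pred - y)**2).mean()
        loss, details = get_loss({'mse': main})   # -> (loss, grads), {name: float}
        backward(loss)                            # calls main.backward()
        return details

net = nn.Linear(4, 1)
train = ToyTrainer(net, loss=nn.Identity(),
                   optimizer=torch.optim.SGD(net.parameters(), lr=0.1))
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(10)]
train(data)   # one epoch; prints the per-epoch loss summary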
tools/viz.py
ADDED
@@ -0,0 +1,266 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import sys
from pdb import set_trace as bb
from PIL import Image
import numpy as np

import matplotlib.pyplot as pl; pl.ion()
import torch
import torch.nn.functional as F

from core import functional as myF
from .common import cpu, nparray, image, image_with_trf


def dbgfig(*args, **kwargs):
    assert len(args) >= 2
    dbg = args[-1]
    if isinstance(dbg, str):
        dbg = dbg.split()
    for name in args[:-1]:
        if {name,'all'} & set(dbg):
            return pl.figure(name, **kwargs)
    return False


def noticks(ax=None):
    if ax is None: ax = pl.gca()
    ax.set_xticks(())
    ax.set_yticks(())
    return ax


def plot_grid( corres, ax1, ax2=None, marker='+' ):
    """ corres = Nx2 or Nx4 list of correspondences
    """
    if marker is True: marker = '+'

    corres = nparray(corres)
    # make beautiful colors
    center = corres[:,[1,0]].mean(axis=0)
    colors = np.arctan2(*(corres[:,[1,0]] - center).T)
    colors = np.int32(64*colors/np.pi) % 128

    all_colors = np.unique(colors)
    palette = {m:pl.cm.hsv(i/float(len(all_colors))) for i,m in enumerate(all_colors)}

    for m in all_colors:
        x, y = corres[colors==m,0:2].T
        ax1.plot(x, y, marker, ms=10, mew=2, color=palette[m], scalex=0, scaley=0)

    if not ax2: return
    for m in all_colors:
        x, y = corres[colors==m,2:4].T
        ax2.plot(x, y, marker, ms=10, mew=2, color=palette[m], scalex=0, scaley=0)


def show_correspondences( img0, img1, corres, F=None, fig='last', show_grid=True, bb=None, clf=False):
    img0, trf0 = img0 if isinstance(img0, tuple) else (img0, torch.eye(3))
    img1, trf1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3))
    if not bb: pl.ioff()
    fig, axes = pl.subplots(2, 2, num=fig_num(fig, 'viz_corres'))
    for i, ax in enumerate(axes.ravel()):
        if clf: ax.cla()
        noticks(ax).numaxis = i % 2
        ax.imshow( [image(img0),image(img1)][i%2] )

    if corres.shape == (3,3):  # corres is a homography matrix
        from pytools.hfuncs import applyh
        H, W = axes[0,0].images[0].get_size()
        pos1 = np.mgrid[:H,:W].reshape(2,-1)[::-1].T
        pos2 = applyh(corres, pos1)
        corres = np.concatenate((pos1,pos2), axis=-1)

    inv = np.linalg.inv
    corres = myF.affmul((inv(nparray(trf0)),inv(nparray(trf1))), nparray(corres))  # images are already downscaled
    print(f">> Displaying {len(corres)} correspondences (move your mouse over the images)")

    (ax1, ax2), (ax3, ax4) = axes
    if corres.shape[-1] > 4:
        corres = corres[corres[:,4]>0,:]  # select non-null correspondences
    if show_grid: plot_grid(corres, ax3, ax4, marker=show_grid)

    def mouse_move(event):
        if event.inaxes==None: return
        numaxis = event.inaxes.numaxis
        if numaxis<0: return
        x,y = event.xdata, event.ydata
        ax1.lines.clear()
        ax2.lines.clear()
        sl = slice(2*numaxis, 2*(numaxis+1))
        n = np.sum((corres[:,sl] - [x,y])**2,axis=1).argmin()  # find nearest point
        print("\rdisplaying #%d (%d,%d) --> (%d,%d), score=%g, code=%g" % (n,
            corres[n,0],corres[n,1],corres[n,2],corres[n,3],
            corres[n,4] if corres.shape[-1] > 4 else np.nan,
            corres[n,5] if corres.shape[-1] > 5 else np.nan), end=' '*7); sys.stdout.flush()
        x,y = corres[n,0:2]
        ax1.plot(x, y, '+', ms=10, mew=2, color='blue', scalex=False, scaley=False)
        x,y = corres[n,2:4]
        ax2.plot(x, y, '+', ms=10, mew=2, color='red', scalex=False, scaley=False)
        if F is not None:
            ax = None
            if numaxis == 0:
                line = corres[n,0:2] @ F[:2] + F[2]
                ax = ax2
            if numaxis == 1:
                line = corres[n,2:4] @ F.T[:2] + F.T[2]
                ax = ax1
            if ax:
                x = np.linspace(-10000,10000,2)
                y = (line[2]+line[0]*x) / -line[1]
                ax.plot(x, y, '-', scalex=0, scaley=0)

        # we redraw only the concerned axes
        renderer = fig.canvas.get_renderer()
        ax1.draw(renderer)
        ax2.draw(renderer)
        fig.canvas.blit(ax1.bbox)
        fig.canvas.blit(ax2.bbox)

    cid_move = fig.canvas.mpl_connect('motion_notify_event',mouse_move)
    pl.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=0.02, hspace=0.02)
    bb() if bb else pl.show()
    fig.canvas.mpl_disconnect(cid_move)


def closest( grid, event ):
    query = (event.xdata, event.ydata)
    n = np.linalg.norm(grid.reshape(-1,2) - query, axis=1).argmin()
    return np.unravel_index(n, grid.shape[:2])


def local_maxima( arr2d, top=5 ):
    maxpooled = F.max_pool2d( arr2d[None, None], 3, padding=1, stride=1)[0,0]
    local_maxima = (arr2d == maxpooled).nonzero()
    order = arr2d[local_maxima.split(1,dim=1)].ravel().argsort()
    return local_maxima[order[-top:]].T  # fixed: honor the `top` parameter (was hard-coded to 5)


def fig_num( fig, default, clf=False ):
    if fig == 'last': num = pl.gcf().number
    elif fig: num = fig.number
    else: num = default
    if clf: pl.figure(num).clf()
    return num


def viz_correlation_maps( img1, img2, corr, level=0, fig=None, grid1=None, grid2=None, show_grid=False, bb=bb, **kw ):
    fig, ((ax1, ax2), (ax4, ax3)) = pl.subplots(2, 2, num=fig_num(fig, 'viz_correlation_maps', clf=True))
    img1 = image(img1)
    img2 = image(img2)
    noticks(ax1).imshow( img1 )
    noticks(ax2).imshow( img2 )
    ax4.hist(corr.ravel()[7:7777777:7].cpu().numpy(), bins=50)

    if isinstance(corr, tuple):
        H1, W1 = corr.grid.shape[:2]
        corr = torch.from_numpy(corr.res_map).view(H1,W1,*corr.res_map.shape[-2:])

    if grid1 is None:
        s1 = int(0.5 + np.sqrt(img1.size / (3 * corr[...,0,0].numel())))  # scale factor between img1 and corr
        grid1 = nparray(torch.ones_like(corr[:,:,0,0]).nonzero()*s1)[:,1::-1]
        if level == 0: grid1 += s1//2
        if show_grid: plot_grid(grid1, ax1)
    grid1 = nparray(grid1).reshape(*corr[:,:,0,0].shape,2)

    if grid2 is None:
        s2 = int(0.5 + np.sqrt(img2.size / (3 * corr[0,0,...].numel())))  # scale factor between img2 and corr
        grid2 = nparray(torch.ones_like(corr[0,0]).nonzero()*s2)[:,::-1]
    grid2 = nparray(grid2).reshape(*corr.shape[2:],2)

    def mouse_move(ev):
        if ev.inaxes is ax1:
            ax3.images.clear()
            n = closest(grid1, ev)
            ax3.imshow(corr[n].cpu().float(), vmin=0, **kw)

            # find local maxima
            lm = nparray(local_maxima(corr[n]))
            for ax in (ax3, ax2):
                if ax is ax2 and not show_grid:
                    ax1.lines.clear()
                    ax1.plot(*grid1[n], 'xr', ms=10, scalex=0, scaley=0)
                ax.lines.clear()
                x, y = grid2[y,x].T if ax is ax2 else lm[::-1]
                if ax is not ax3:
                    ax.plot(x, y, 'xr', ms=10, scalex=0, scaley=0, label='local maxima')
            print(f"\rCorr channel {n}. Min={corr[n].min():g}, Avg={corr[n].mean():g}, Max={corr[n].max():g}   ", end='')

    mouse_move(FakeEvent(0,0,inaxes=ax1))
    cid_move = fig.canvas.mpl_connect('motion_notify_event', mouse_move)
    pl.subplots_adjust(0,0,1,1,0,0)
    pl.sca(ax4)
    if bb: bb(); fig.canvas.mpl_disconnect(cid_move)


def viz_correspondences( img1, img2, corres1, corres2, fig=None ):
    img1, img2 = map(image, (img1, img2))
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = pl.subplots(3,2, num=fig_num(fig, 'viz_correspondences'))
    for ax in fig.axes: noticks(ax)
    ax1.imshow( img1 )
    ax2.imshow( img2 )
    ax3.imshow( img1 )
    ax4.imshow( img2 )
    corres1, corres2 = map(cpu, (corres1, corres2))
    plot_grid( corres1[0], ax1, ax2 )
    plot_grid( corres2[0], ax3, ax4 )

    corres1, corres2 = corres1[1].float(), corres2[1].float()
    ceiling = np.ceil(max(corres1.max(), corres2.max()).item())
    ax5.imshow( corres1, vmin=0, vmax=ceiling )
    ax6.imshow( corres2, vmin=0, vmax=ceiling )
    bb()


class FakeEvent:
    def __init__(self, xdata, ydata, **kw):
        self.xdata = xdata
        self.ydata = ydata
        for name, val in kw.items():
            setattr(self, name, val)


def show_random_pairs( db, pair_idxs=None, **kw ):
    print('Showing random pairs from', db)

    if pair_idxs is None:
        pair_idxs = np.random.permutation(len(db))

    for pair_idx in pair_idxs:
        print(f'{pair_idx=}')
        try:
            img1_path, img2_path = map(db.imgs.get_image_path, db.pairs[pair_idx])
            print(f'{img1_path=}\n{img2_path=}')
            if hasattr(db, 'get_corres_path'):
                print(f'corres_path = {db.get_corres_path(pair_idx)}')
        except: pass
        (img1, img2), gt = db[pair_idx]

        if 'corres' in gt:
            corres = gt['corres']
        else:
            # make corres from homography
            from datasets.utils import corres_from_homography
            corres = corres_from_homography(gt['homography'], *img1.size)

        show_correspondences(img1, img2, corres, **kw)


if __name__=='__main__':
    import argparse
    import test_singlescale as pump

    parser = argparse.ArgumentParser('Correspondence visualization')
    parser.add_argument('--img1', required=True, help='path to first image')
    parser.add_argument('--img2', required=True, help='path to second image')
    parser.add_argument('--corres', required=True, help='path to correspondences')
    args = parser.parse_args()

    corres = np.load(args.corres)['corres']

    args.resize = 0  # don't resize images
    imgs = tuple(map(image, pump.Main.load_images(args)))

    show_correspondences(*imgs, corres)
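
dbgfig gates the creation of debug figures on a set of requested names; a quick sketch of the pattern used throughout the test scripts:

from tools.viz import dbgfig

dbg = {'corres'}                 # e.g. from --dbg corres on the command line
if dbgfig('corres', dbg):        # opens pl.figure('corres') and returns it
    print('corres figure requested: draw into it here')
if not dbgfig('ncc', dbg):       # 'ncc' was not requested -> returns False
    print('ncc figure skipped')
# passing dbg={'all'} would enable every dbgfig call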
train.py
ADDED
@@ -0,0 +1,121 @@
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os
import os.path as osp  # added: osp is used below for the checkpoint paths
import torch
import torch.optim as optim
import torchvision.transforms as tvf

from tools import common, trainer
from datasets import *
from core.conv_mixer import ConvMixer
from core.losses import *


def parse_args():
    import argparse
    parser = argparse.ArgumentParser("Script to train PUMP")

    parser.add_argument("--pretrained", type=str, default="", help='pretrained model path')
    parser.add_argument("--save-path", type=str, required=True, help='directory to save the model')

    parser.add_argument("--epochs", type=int, default=50, help='number of training epochs')
    parser.add_argument("--batch-size", "--bs", type=int, default=16, help="batch size")
    parser.add_argument("--learning-rate", "--lr", type=float, default=1e-4)  # fixed: the learning rate is a float (was type=str)
    parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)

    parser.add_argument("--threads", type=int, default=8, help='number of worker threads')
    parser.add_argument("--device", default='cuda')

    args = parser.parse_args()
    return args


def main( args ):
    device = args.device
    common.mkdir_for(args.save_path)

    # create the data loader
    db = BalancedCatImagePairs(
            3125, SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)'),
            4875, SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)'),
            8000, SfM120k_Pairs())

    db = FastPairLoader(db,
            crop=256, transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise(25)',
            p_swap=0.5, p_flip=0.5, scale_jitter=0.5)

    print("Training image database =", db)
    data_loader = torch.utils.data.DataLoader(db, batch_size=args.batch_size, shuffle=True,
            num_workers=args.threads, collate_fn=collate_ordered, pin_memory=False, drop_last=True,
            worker_init_fn=WorkerWithRngInit())

    # create the network
    net = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9)
    print(f"\n>> Creating {type(net).__name__} net ( Model size: {common.model_size(net)/1e6:.1f}M parameters )")

    # create the losses
    loss = MultiLoss(alpha=0.3,
            loss_sup = PixelAPLoss(nq=20, inner_bw=True, sampler=NghSampler(ngh=7)),
            loss_unsup = DeepMatchingLoss(eps=0.03))

    # create the optimizer
    optimizer = optim.Adam( [p for p in net.parameters() if p.requires_grad],
                            lr=args.learning_rate, weight_decay=args.weight_decay)

    train = MyTrainer(net, loss, optimizer).to(device)

    # initialization
    final_model_path = osp.join(args.save_path,'model.pt')
    last_model_path = osp.join(args.save_path,'model.pt.last')
    if osp.exists( final_model_path ):
        print('Already trained, nothing to do!')
        return
    elif args.pretrained:
        train.load( args.pretrained )
    elif osp.exists( last_model_path ):
        train.load( last_model_path )

    train = train.to(args.device)
    if ',' in os.environ.get('CUDA_VISIBLE_DEVICES',''):
        train.distribute()

    # training loop
    while train.epoch < args.epochs:
        # shuffle the dataset (select new pairs)
        data_loader.dataset.set_epoch(train.epoch)

        train(data_loader)

        train.save(last_model_path)

    # save the final model
    torch.save(train.model.state_dict(), open(final_model_path,'wb'))


totensor = tvf.Compose([
    common.ToTensor(),
    tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

class MyTrainer (trainer.Trainer):
    """ This class implements the network training.
        Below is the only function one needs to overload to specify how to do the backprop.
    """
    def forward_backward(self, inputs):
        assert torch.is_grad_enabled() and self.net.training

        (img1, img2), labels = inputs
        output1 = self.net(totensor(img1))
        output2 = self.net(totensor(img2))

        loss, details = trainer.get_loss(self.loss(output1, output2, img1=img1, img2=img2, **labels))
        trainer.backward(loss)
        return details


if __name__ == '__main__':
    main(parse_args())
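
Finally, a hypothetical programmatic launch of training, equivalent to the CLI entry point above (the save path is a placeholder, and the datasets referenced by main are assumed to have been downloaded, e.g. via download_training_data.sh):

from argparse import Namespace
import train

args = Namespace(pretrained='', save_path='runs/pump',  # placeholder directory
                 epochs=50, batch_size=16, learning_rate=1e-4,
                 weight_decay=5e-4, threads=8, device='cuda')
train.main(args)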