Philippe Weinzaepfel committed on
Commit 3ef85e9 · 0 Parent(s):

huggingface demo
.gitattributes ADDED
@@ -0,0 +1,27 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,446 @@
1
+ PUMP
2
+ Copyright (c) 2022-present NAVER Corp.
3
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4
+
5
+ A summary of the CC BY-NC-SA 4.0 license is located here:
6
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ ---
9
+
10
+ Attribution-NonCommercial-ShareAlike 4.0 International
11
+
12
+ =======================================================================
13
+
14
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
15
+ does not provide legal services or legal advice. Distribution of
16
+ Creative Commons public licenses does not create a lawyer-client or
17
+ other relationship. Creative Commons makes its licenses and related
18
+ information available on an "as-is" basis. Creative Commons gives no
19
+ warranties regarding its licenses, any material licensed under their
20
+ terms and conditions, or any related information. Creative Commons
21
+ disclaims all liability for damages resulting from their use to the
22
+ fullest extent possible.
23
+
24
+ Using Creative Commons Public Licenses
25
+
26
+ Creative Commons public licenses provide a standard set of terms and
27
+ conditions that creators and other rights holders may use to share
28
+ original works of authorship and other material subject to copyright
29
+ and certain other rights specified in the public license below. The
30
+ following considerations are for informational purposes only, are not
31
+ exhaustive, and do not form part of our licenses.
32
+
33
+ Considerations for licensors: Our public licenses are
34
+ intended for use by those authorized to give the public
35
+ permission to use material in ways otherwise restricted by
36
+ copyright and certain other rights. Our licenses are
37
+ irrevocable. Licensors should read and understand the terms
38
+ and conditions of the license they choose before applying it.
39
+ Licensors should also secure all rights necessary before
40
+ applying our licenses so that the public can reuse the
41
+ material as expected. Licensors should clearly mark any
42
+ material not subject to the license. This includes other CC-
43
+ licensed material, or material used under an exception or
44
+ limitation to copyright. More considerations for licensors:
45
+ wiki.creativecommons.org/Considerations_for_licensors
46
+
47
+ Considerations for the public: By using one of our public
48
+ licenses, a licensor grants the public permission to use the
49
+ licensed material under specified terms and conditions. If
50
+ the licensor's permission is not necessary for any reason--for
51
+ example, because of any applicable exception or limitation to
52
+ copyright--then that use is not regulated by the license. Our
53
+ licenses grant only permissions under copyright and certain
54
+ other rights that a licensor has authority to grant. Use of
55
+ the licensed material may still be restricted for other
56
+ reasons, including because others have copyright or other
57
+ rights in the material. A licensor may make special requests,
58
+ such as asking that all changes be marked or described.
59
+ Although not required by our licenses, you are encouraged to
60
+ respect those requests where reasonable. More considerations
61
+ for the public:
62
+ wiki.creativecommons.org/Considerations_for_licensees
63
+
64
+ =======================================================================
65
+
66
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
67
+ Public License
68
+
69
+ By exercising the Licensed Rights (defined below), You accept and agree
70
+ to be bound by the terms and conditions of this Creative Commons
71
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
72
+ ("Public License"). To the extent this Public License may be
73
+ interpreted as a contract, You are granted the Licensed Rights in
74
+ consideration of Your acceptance of these terms and conditions, and the
75
+ Licensor grants You such rights in consideration of benefits the
76
+ Licensor receives from making the Licensed Material available under
77
+ these terms and conditions.
78
+
79
+
80
+ Section 1 -- Definitions.
81
+
82
+ a. Adapted Material means material subject to Copyright and Similar
83
+ Rights that is derived from or based upon the Licensed Material
84
+ and in which the Licensed Material is translated, altered,
85
+ arranged, transformed, or otherwise modified in a manner requiring
86
+ permission under the Copyright and Similar Rights held by the
87
+ Licensor. For purposes of this Public License, where the Licensed
88
+ Material is a musical work, performance, or sound recording,
89
+ Adapted Material is always produced where the Licensed Material is
90
+ synched in timed relation with a moving image.
91
+
92
+ b. Adapter's License means the license You apply to Your Copyright
93
+ and Similar Rights in Your contributions to Adapted Material in
94
+ accordance with the terms and conditions of this Public License.
95
+
96
+ c. BY-NC-SA Compatible License means a license listed at
97
+ creativecommons.org/compatiblelicenses, approved by Creative
98
+ Commons as essentially the equivalent of this Public License.
99
+
100
+ d. Copyright and Similar Rights means copyright and/or similar rights
101
+ closely related to copyright including, without limitation,
102
+ performance, broadcast, sound recording, and Sui Generis Database
103
+ Rights, without regard to how the rights are labeled or
104
+ categorized. For purposes of this Public License, the rights
105
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
106
+ Rights.
107
+
108
+ e. Effective Technological Measures means those measures that, in the
109
+ absence of proper authority, may not be circumvented under laws
110
+ fulfilling obligations under Article 11 of the WIPO Copyright
111
+ Treaty adopted on December 20, 1996, and/or similar international
112
+ agreements.
113
+
114
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
115
+ any other exception or limitation to Copyright and Similar Rights
116
+ that applies to Your use of the Licensed Material.
117
+
118
+ g. License Elements means the license attributes listed in the name
119
+ of a Creative Commons Public License. The License Elements of this
120
+ Public License are Attribution, NonCommercial, and ShareAlike.
121
+
122
+ h. Licensed Material means the artistic or literary work, database,
123
+ or other material to which the Licensor applied this Public
124
+ License.
125
+
126
+ i. Licensed Rights means the rights granted to You subject to the
127
+ terms and conditions of this Public License, which are limited to
128
+ all Copyright and Similar Rights that apply to Your use of the
129
+ Licensed Material and that the Licensor has authority to license.
130
+
131
+ j. Licensor means the individual(s) or entity(ies) granting rights
132
+ under this Public License.
133
+
134
+ k. NonCommercial means not primarily intended for or directed towards
135
+ commercial advantage or monetary compensation. For purposes of
136
+ this Public License, the exchange of the Licensed Material for
137
+ other material subject to Copyright and Similar Rights by digital
138
+ file-sharing or similar means is NonCommercial provided there is
139
+ no payment of monetary compensation in connection with the
140
+ exchange.
141
+
142
+ l. Share means to provide material to the public by any means or
143
+ process that requires permission under the Licensed Rights, such
144
+ as reproduction, public display, public performance, distribution,
145
+ dissemination, communication, or importation, and to make material
146
+ available to the public including in ways that members of the
147
+ public may access the material from a place and at a time
148
+ individually chosen by them.
149
+
150
+ m. Sui Generis Database Rights means rights other than copyright
151
+ resulting from Directive 96/9/EC of the European Parliament and of
152
+ the Council of 11 March 1996 on the legal protection of databases,
153
+ as amended and/or succeeded, as well as other essentially
154
+ equivalent rights anywhere in the world.
155
+
156
+ n. You means the individual or entity exercising the Licensed Rights
157
+ under this Public License. Your has a corresponding meaning.
158
+
159
+
160
+ Section 2 -- Scope.
161
+
162
+ a. License grant.
163
+
164
+ 1. Subject to the terms and conditions of this Public License,
165
+ the Licensor hereby grants You a worldwide, royalty-free,
166
+ non-sublicensable, non-exclusive, irrevocable license to
167
+ exercise the Licensed Rights in the Licensed Material to:
168
+
169
+ a. reproduce and Share the Licensed Material, in whole or
170
+ in part, for NonCommercial purposes only; and
171
+
172
+ b. produce, reproduce, and Share Adapted Material for
173
+ NonCommercial purposes only.
174
+
175
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
176
+ Exceptions and Limitations apply to Your use, this Public
177
+ License does not apply, and You do not need to comply with
178
+ its terms and conditions.
179
+
180
+ 3. Term. The term of this Public License is specified in Section
181
+ 6(a).
182
+
183
+ 4. Media and formats; technical modifications allowed. The
184
+ Licensor authorizes You to exercise the Licensed Rights in
185
+ all media and formats whether now known or hereafter created,
186
+ and to make technical modifications necessary to do so. The
187
+ Licensor waives and/or agrees not to assert any right or
188
+ authority to forbid You from making technical modifications
189
+ necessary to exercise the Licensed Rights, including
190
+ technical modifications necessary to circumvent Effective
191
+ Technological Measures. For purposes of this Public License,
192
+ simply making modifications authorized by this Section 2(a)
193
+ (4) never produces Adapted Material.
194
+
195
+ 5. Downstream recipients.
196
+
197
+ a. Offer from the Licensor -- Licensed Material. Every
198
+ recipient of the Licensed Material automatically
199
+ receives an offer from the Licensor to exercise the
200
+ Licensed Rights under the terms and conditions of this
201
+ Public License.
202
+
203
+ b. Additional offer from the Licensor -- Adapted Material.
204
+ Every recipient of Adapted Material from You
205
+ automatically receives an offer from the Licensor to
206
+ exercise the Licensed Rights in the Adapted Material
207
+ under the conditions of the Adapter's License You apply.
208
+
209
+ c. No downstream restrictions. You may not offer or impose
210
+ any additional or different terms or conditions on, or
211
+ apply any Effective Technological Measures to, the
212
+ Licensed Material if doing so restricts exercise of the
213
+ Licensed Rights by any recipient of the Licensed
214
+ Material.
215
+
216
+ 6. No endorsement. Nothing in this Public License constitutes or
217
+ may be construed as permission to assert or imply that You
218
+ are, or that Your use of the Licensed Material is, connected
219
+ with, or sponsored, endorsed, or granted official status by,
220
+ the Licensor or others designated to receive attribution as
221
+ provided in Section 3(a)(1)(A)(i).
222
+
223
+ b. Other rights.
224
+
225
+ 1. Moral rights, such as the right of integrity, are not
226
+ licensed under this Public License, nor are publicity,
227
+ privacy, and/or other similar personality rights; however, to
228
+ the extent possible, the Licensor waives and/or agrees not to
229
+ assert any such rights held by the Licensor to the limited
230
+ extent necessary to allow You to exercise the Licensed
231
+ Rights, but not otherwise.
232
+
233
+ 2. Patent and trademark rights are not licensed under this
234
+ Public License.
235
+
236
+ 3. To the extent possible, the Licensor waives any right to
237
+ collect royalties from You for the exercise of the Licensed
238
+ Rights, whether directly or through a collecting society
239
+ under any voluntary or waivable statutory or compulsory
240
+ licensing scheme. In all other cases the Licensor expressly
241
+ reserves any right to collect such royalties, including when
242
+ the Licensed Material is used other than for NonCommercial
243
+ purposes.
244
+
245
+
246
+ Section 3 -- License Conditions.
247
+
248
+ Your exercise of the Licensed Rights is expressly made subject to the
249
+ following conditions.
250
+
251
+ a. Attribution.
252
+
253
+ 1. If You Share the Licensed Material (including in modified
254
+ form), You must:
255
+
256
+ a. retain the following if it is supplied by the Licensor
257
+ with the Licensed Material:
258
+
259
+ i. identification of the creator(s) of the Licensed
260
+ Material and any others designated to receive
261
+ attribution, in any reasonable manner requested by
262
+ the Licensor (including by pseudonym if
263
+ designated);
264
+
265
+ ii. a copyright notice;
266
+
267
+ iii. a notice that refers to this Public License;
268
+
269
+ iv. a notice that refers to the disclaimer of
270
+ warranties;
271
+
272
+ v. a URI or hyperlink to the Licensed Material to the
273
+ extent reasonably practicable;
274
+
275
+ b. indicate if You modified the Licensed Material and
276
+ retain an indication of any previous modifications; and
277
+
278
+ c. indicate the Licensed Material is licensed under this
279
+ Public License, and include the text of, or the URI or
280
+ hyperlink to, this Public License.
281
+
282
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
283
+ reasonable manner based on the medium, means, and context in
284
+ which You Share the Licensed Material. For example, it may be
285
+ reasonable to satisfy the conditions by providing a URI or
286
+ hyperlink to a resource that includes the required
287
+ information.
288
+ 3. If requested by the Licensor, You must remove any of the
289
+ information required by Section 3(a)(1)(A) to the extent
290
+ reasonably practicable.
291
+
292
+ b. ShareAlike.
293
+
294
+ In addition to the conditions in Section 3(a), if You Share
295
+ Adapted Material You produce, the following conditions also apply.
296
+
297
+ 1. The Adapter's License You apply must be a Creative Commons
298
+ license with the same License Elements, this version or
299
+ later, or a BY-NC-SA Compatible License.
300
+
301
+ 2. You must include the text of, or the URI or hyperlink to, the
302
+ Adapter's License You apply. You may satisfy this condition
303
+ in any reasonable manner based on the medium, means, and
304
+ context in which You Share Adapted Material.
305
+
306
+ 3. You may not offer or impose any additional or different terms
307
+ or conditions on, or apply any Effective Technological
308
+ Measures to, Adapted Material that restrict exercise of the
309
+ rights granted under the Adapter's License You apply.
310
+
311
+
312
+ Section 4 -- Sui Generis Database Rights.
313
+
314
+ Where the Licensed Rights include Sui Generis Database Rights that
315
+ apply to Your use of the Licensed Material:
316
+
317
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
318
+ to extract, reuse, reproduce, and Share all or a substantial
319
+ portion of the contents of the database for NonCommercial purposes
320
+ only;
321
+
322
+ b. if You include all or a substantial portion of the database
323
+ contents in a database in which You have Sui Generis Database
324
+ Rights, then the database in which You have Sui Generis Database
325
+ Rights (but not its individual contents) is Adapted Material,
326
+ including for purposes of Section 3(b); and
327
+
328
+ c. You must comply with the conditions in Section 3(a) if You Share
329
+ all or a substantial portion of the contents of the database.
330
+
331
+ For the avoidance of doubt, this Section 4 supplements and does not
332
+ replace Your obligations under this Public License where the Licensed
333
+ Rights include other Copyright and Similar Rights.
334
+
335
+
336
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
337
+
338
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
339
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
340
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
341
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
342
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
343
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
344
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
345
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
346
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
347
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
348
+
349
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
350
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
351
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
352
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
353
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
354
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
355
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
356
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
357
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
358
+
359
+ c. The disclaimer of warranties and limitation of liability provided
360
+ above shall be interpreted in a manner that, to the extent
361
+ possible, most closely approximates an absolute disclaimer and
362
+ waiver of all liability.
363
+
364
+
365
+ Section 6 -- Term and Termination.
366
+
367
+ a. This Public License applies for the term of the Copyright and
368
+ Similar Rights licensed here. However, if You fail to comply with
369
+ this Public License, then Your rights under this Public License
370
+ terminate automatically.
371
+
372
+ b. Where Your right to use the Licensed Material has terminated under
373
+ Section 6(a), it reinstates:
374
+
375
+ 1. automatically as of the date the violation is cured, provided
376
+ it is cured within 30 days of Your discovery of the
377
+ violation; or
378
+
379
+ 2. upon express reinstatement by the Licensor.
380
+
381
+ For the avoidance of doubt, this Section 6(b) does not affect any
382
+ right the Licensor may have to seek remedies for Your violations
383
+ of this Public License.
384
+
385
+ c. For the avoidance of doubt, the Licensor may also offer the
386
+ Licensed Material under separate terms or conditions or stop
387
+ distributing the Licensed Material at any time; however, doing so
388
+ will not terminate this Public License.
389
+
390
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
391
+ License.
392
+
393
+
394
+ Section 7 -- Other Terms and Conditions.
395
+
396
+ a. The Licensor shall not be bound by any additional or different
397
+ terms or conditions communicated by You unless expressly agreed.
398
+
399
+ b. Any arrangements, understandings, or agreements regarding the
400
+ Licensed Material not stated herein are separate from and
401
+ independent of the terms and conditions of this Public License.
402
+
403
+
404
+ Section 8 -- Interpretation.
405
+
406
+ a. For the avoidance of doubt, this Public License does not, and
407
+ shall not be interpreted to, reduce, limit, restrict, or impose
408
+ conditions on any use of the Licensed Material that could lawfully
409
+ be made without permission under this Public License.
410
+
411
+ b. To the extent possible, if any provision of this Public License is
412
+ deemed unenforceable, it shall be automatically reformed to the
413
+ minimum extent necessary to make it enforceable. If the provision
414
+ cannot be reformed, it shall be severed from this Public License
415
+ without affecting the enforceability of the remaining terms and
416
+ conditions.
417
+
418
+ c. No term or condition of this Public License will be waived and no
419
+ failure to comply consented to unless expressly agreed to by the
420
+ Licensor.
421
+
422
+ d. Nothing in this Public License constitutes or may be interpreted
423
+ as a limitation upon, or waiver of, any privileges and immunities
424
+ that apply to the Licensor or You, including from the legal
425
+ processes of any jurisdiction or authority.
426
+
427
+ =======================================================================
428
+
429
+ Creative Commons is not a party to its public
430
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
431
+ its public licenses to material it publishes and in those instances
432
+ will be considered the “Licensor.” The text of the Creative Commons
433
+ public licenses is dedicated to the public domain under the CC0 Public
434
+ Domain Dedication. Except for the limited purpose of indicating that
435
+ material is shared under a Creative Commons public license or as
436
+ otherwise permitted by the Creative Commons policies published at
437
+ creativecommons.org/policies, Creative Commons does not authorize the
438
+ use of the trademark "Creative Commons" or any other trademark or logo
439
+ of Creative Commons without its prior written consent including,
440
+ without limitation, in connection with any unauthorized modifications
441
+ to any of its public licenses or any other arrangements,
442
+ understandings, or agreements concerning use of licensed material. For
443
+ the avoidance of doubt, this paragraph does not form part of the
444
+ public licenses.
445
+
446
+ Creative Commons may be contacted at creativecommons.org.
NOTICE ADDED
@@ -0,0 +1,46 @@
1
+ PUMP
2
+ Copyright (c) 2022-present NAVER Corp.
3
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license
4
+
5
+ --------------------------------------------------------------------------------------
6
+
7
+ This project contains subcomponents with separate copyright notices and license terms.
8
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
9
+
10
+ =====
11
+
12
+ pytorch/vision
13
+ https://github.com/pytorch/vision
14
+
15
+
16
+ BSD 3-Clause License
17
+
18
+ Copyright (c) Soumith Chintala 2016,
19
+ All rights reserved.
20
+
21
+ Redistribution and use in source and binary forms, with or without
22
+ modification, are permitted provided that the following conditions are met:
23
+
24
+ * Redistributions of source code must retain the above copyright notice, this
25
+ list of conditions and the following disclaimer.
26
+
27
+ * Redistributions in binary form must reproduce the above copyright notice,
28
+ this list of conditions and the following disclaimer in the documentation
29
+ and/or other materials provided with the distribution.
30
+
31
+ * Neither the name of the copyright holder nor the names of its
32
+ contributors may be used to endorse or promote products derived from
33
+ this software without specific prior written permission.
34
+
35
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
36
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
37
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
38
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
39
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
40
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
41
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
42
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
43
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
44
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45
+
46
+ =====
README.md ADDED
@@ -0,0 +1,163 @@
1
+ ---
2
+ title: PUMP
3
+ emoji: 📚
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # PUMP: pyramidal and uniqueness matching priors for unsupervised learning of local features #
12
+ ![image](imgs/teaser_paper.jpg)
13
+
14
+ Official repository for the following [paper](https://europe.naverlabs.com/research/publications/pump-pyramidal-and-uniqueness-matching-priors-for-unsupervised-learning-of-local-features/):
15
+
16
+ ```text
17
+ @inproceedings{cvpr22_pump,
18
+ author = {Jerome Revaud and Vincent Leroy and Philippe Weinzaepfel and Boris Chidlovskii},
19
+ title = {PUMP: pyramidal and uniqueness matching priors for unsupervised learning of local features},
20
+ booktitle = {CVPR},
21
+ year = {2022},
22
+ }
23
+ ```
24
+ ![image](imgs/overview.png)
25
+
26
+ License
27
+ -------
28
+ Our code is released under the CC BY-NC-SA 4.0 License (see [LICENSE](LICENSE) for more details) and is available for non-commercial use only.
29
+
30
+
31
+ Requirements
32
+ ------------
33
+ - Python 3.8+ equipped with standard scientific packages and PyTorch / TorchVision:
34
+ ```
35
+ tqdm >= 4
36
+ PIL >= 8.1.1
37
+ numpy >= 1.19
38
+ scipy >= 1.6
39
+ torch >= 1.10.0
40
+ torchvision >= 0.9.0
41
+ matplotlib >= 3.3.4
42
+ ```
43
+ - the CUDA toolkit, to compile the custom CUDA kernels
44
+ ```bash
45
+ cd core/cuda_deepm/
46
+ python setup.py install
47
+ ```
48
+
49
+ Warping Demo
50
+ ------------
51
+
52
+ ```bash
53
+ python demo_warping.py
54
+ ```
55
+
56
+ You should see the following result:
57
+ ![image](imgs/demo_warp.jpg)
58
+
59
+ Test usage
60
+ ----------
61
+
62
+ We provide 4 variations of the pairwise matching code, named `test_xxxscale_yyy.py`:
63
+ - xxx: `single`-scale or `multi`-scale.
64
+ Single-scale can cope with 0.75~1.33x scale difference at most.
65
+ The multi-scale version can additionally be made rotation-invariant if requested.
66
+ - yyy: recursive or not. The recursive version is slower but provides denser and better outputs.
67
+
68
+ For most cases, you want to use `test_multiscale.py`:
69
+ ```bash
70
+ python test_multiscale.py
71
+ --img1 path/to/img1
72
+ --img2 path/to/img2
73
+ --resize 600 # important, see below
74
+ --post-filter
75
+ --output path/to/correspondences.npy
76
+ ```
77
+
78
+ It outputs a numpy binary file with the field `file_data['corres']` containing a list of correspondences.
79
+ The row format is `[x1, y1, x2, y2, score, scale_rot_code]`.
80
+ Use `core.functional.decode_scale_rot(code) --> (scale, angle_in_degrees)` to decode the `scale_rot_code`.
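+ 
+ A minimal sketch of reading this file in Python (the path matches the `--output` placeholder above; the `allow_pickle`/`.item()` loading convention is an assumption about how the dictionary was saved):
+ ```python
+ import numpy as np
+ from core.functional import decode_scale_rot
+ 
+ # assumption: the .npy stores a pickled dict with a 'corres' array
+ file_data = np.load('path/to/correspondences.npy', allow_pickle=True).item()
+ corres = file_data['corres']  # rows: [x1, y1, x2, y2, score, scale_rot_code]
+ for x1, y1, x2, y2, score, code in corres:
+     scale, angle = decode_scale_rot(code)  # angle in degrees
+     print(f'({x1:.1f},{y1:.1f}) -> ({x2:.1f},{y2:.1f})  score={score:.2f}  scale={scale:.2f}  rot={angle:.1f}')
+ ```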
81
+
82
+
83
+ #### Optional parameters:
84
+
85
+ - **Prior image resize**: `--resize SIZE`
86
+
87
+ This is a very important parameter. In general, the bigger, the better (and slower).
88
+ Be aware that the memory footprint explodes with the image size.
89
+ Here is the table of maximum `--resize` values depending on the image aspect-ratio:
90
+
91
+ | Aspect-ratio | Example img sizes | GPU memory | resize |
92
+ |--------------|--------------------|------------|--------|
93
+ | 4/3 | 800x600, 1024x768 | 16 GB | 600 |
94
+ | 4/3 | 800x600, 1024x768 | 22 GB | 680 |
95
+ | 4/3 | 800x600, 1024x768 | 32 GB | 760 |
96
+ | 1/1 | 1024x1024 | 16 GB | 540 |
97
+ | 1/1 | 1024x1024 | 22 GB | 600 |
98
+ | 1/1 | 1024x1024 | 32 GB | 660 |
99
+
100
+ (Formula: `memory_in_bytes = (W1*H1*W2*H2)*1.333*2/16`)
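+ 
+ As a rough sanity check of this formula (assuming `--resize` bounds the longest image side, so a 4/3 pair at `--resize 600` is about 600x450 pixels per image):
+ ```python
+ # GPU memory estimate from the formula above, in GB
+ def pump_mem_gb(w1, h1, w2, h2):
+     return (w1 * h1 * w2 * h2) * 1.333 * 2 / 16 / 1e9
+ 
+ print(pump_mem_gb(600, 450, 600, 450))  # ~12.1 GB, consistent with the 16 GB / resize 600 row
+ ```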
101
+
102
+ - **Base descriptor**: `--desc {PUMP, PUMP-stytrf}`
103
+
104
+ We provide the `PUMP` descriptor from our paper, as well as `PUMP-stytrf` (with additional style-transfer training).
105
+ Defaults to `PUMP-stytrf`.
106
+
107
+ - **Scale**: `--max-scale SCALE`
108
+
109
+ By default, this value is set to 4, meaning that PUMP is _at least_ invariant to a 4x zoom-in or
110
+ zoom-out. In practically all cases, this is more than enough. You may reduce this value to speed up
111
+ computation if you know such a large scale range is unnecessary.
112
+
113
+ - **Rotation**: `--max-rot DEGREES`
114
+
115
+ By default, PUMP is not rotation-invariant. To enforce rotation invariance, you need to specify
116
+ the amount of rotation it can tolerate. The more, the slower. Maximum value is 180.
117
+ If you know that images are not vertically oriented, you can just use 90 degrees.
118
+
119
+ - **post-filter**: `--post-filter "option1=val1,option2=val2,..."`
120
+
121
+ When activated, post-filtering removes spurious correspondences based on their local consistency.
122
+ See `python post_filter.py --help` for details about the possible options.
123
+ It is geometry-agnostic and naturally supports dynamic scenes.
124
+ If you want to output _pixel-dense_ correspondences (a.k.a. _optical flow_), you need to post-process
125
+ the correspondences with `--post-filter densify=True`. See `demo_warping.py` for an example.
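+ 
+ For example, to additionally densify the output (the image paths are placeholders, and further post-filter options can be appended inside the same quoted string):
+ ```bash
+ python test_multiscale.py --img1 path/to/img1 --img2 path/to/img2 --resize 600 \
+     --post-filter "densify=True" --output path/to/dense_correspondences.npy
+ ```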
126
+
127
+
128
+ #### Visualization of results:
129
+ ```bash
130
+ python -m tools.viz --img1 path/to/img1 --img2 path/to/img2 --corres path/to/correspondences.npy
131
+ ```
132
+
133
+ Reproducing results on the ETH-3D dataset
134
+ -----------------------------------------
135
+
136
+ 1. Download the ETH-3D dataset from [their website](https://www.eth3d.net/datasets) and extract it in `datasets/eth3d/`
137
+
138
+ 2. Run the code `python run_ETH3D.py`. You should get results slightly better than those reported in the paper.
139
+
140
+
141
+ Training PUMP from scratch
142
+ --------------------------
143
+
144
+ 1. Download the training data with
145
+ ```bash
146
+ bash download_training_data.sh
147
+ ```
148
+
149
+ This consists of web images from [this paper](http://cmp.felk.cvut.cz/revisitop/) for the self-supervised loss (as in [R2D2](https://github.com/naver/r2d2))
150
+ and image pairs from the [SfM120k dataset](http://cmp.felk.cvut.cz/cnnimageretrieval/) with automatically
151
+ extracted pixel correspondences. Note that correspondences are *not* used in the loss, since the loss is
152
+ unsupervised. They are only necessary so that random cropping produces pairs of crops at least partially aligned.
153
+ Therefore, correspondences do not need to be 100% correct or even pixel-precise.
154
+
155
+ 2. Run `python train.py --save-path <output_dir>/`
156
+
157
+ Note that the training code is quite rudimentary (only supports `nn.DataParallel`,
158
+ no `DistributedDataParallel` support at the moment, and no validation phase either).
159
+
160
+ 3. Move and rename your final checkpoint to `checkpoints/NAME.pt` and test it with
161
+ ```bash
162
+ python test_multiscale.py ... --desc NAME
163
+ ```
app.py ADDED
@@ -0,0 +1,77 @@
1
+ import gradio as gr
2
+ import sys, os
3
+ import torch
4
+ import matplotlib.pylab as plt
5
+
6
+ def pump_matching(img1, img2, trained_with_st=False, scale=300, max_scale=1, max_rot=0, use_gpu=False):
7
+
8
+ use_singlescale = max_scale==1 and max_rot==0
9
+ if use_singlescale: # single
10
+ from test_singlescale import Main, arg_parser
11
+ else:
12
+ from test_multiscale import Main, arg_parser
13
+ parser = arg_parser()
14
+
15
+ args_list = ['--img1','dummy','--img2','dummy','--post-filter', '--desc','PUMP-stytrf' if trained_with_st else 'PUMP','--resize',str(scale)]
16
+ if not use_gpu:
17
+ args_list += ['--device', 'cpu']
18
+ if not use_singlescale:
19
+ args_list += ['--max-scale',str(max_scale),'--max-rot',str(max_rot)]
20
+
21
+ args = parser.parse_args(args_list)
22
+
23
+ corres = Main().run_from_args_with_images(img1, img2, args)
24
+
25
+ fig1 = plt.figure(1)
26
+ plt.imshow(img1)
27
+ ax1 = plt.gca()
28
+ ax1.axis('off')
29
+ plt.tight_layout()
30
+
31
+ fig2 = plt.figure(2)
32
+ plt.imshow(img2)
33
+ ax2 = plt.gca()
34
+ ax2.axis('off')
35
+ plt.tight_layout()
36
+
37
+ from tools.viz import plot_grid
38
+ if corres.shape[-1] > 4:
39
+ corres = corres[corres[:,4]>0,:] # select non-null correspondences
40
+ if corres.shape[0]>0: plot_grid(corres, ax1, ax2, marker='+')
41
+
42
+ img1 = None
43
+ img2 = None
44
+
45
+ return fig1, fig2
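+ 
+ # Usage sketch outside of Gradio (the image paths below are the demo examples
+ # declared further down; saving the figures to disk is just an illustration):
+ #   from PIL import Image
+ #   fig1, fig2 = pump_matching(Image.open('datasets/demo_warp/mountains_src.jpg'),
+ #                              Image.open('datasets/demo_warp/mountains_tgt.jpg'))
+ #   fig1.savefig('matches_img1.png'); fig2.savefig('matches_img2.png')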
46
+
47
+ has_cuda = torch.cuda.is_available() and torch.cuda.device_count()>0
48
+
49
+ title = "PUMP local descriptor demo"
50
+ description = "This is a visualization demo for the PUMP local descriptors presented in our CVPR 2022 paper <b><a href='https://europe.naverlabs.com/research/publications/pump-pyramidal-and-uniqueness-matching-priors-for-unsupervised-learning-of-local-features/' target='_blank'>PUMP: Pyramidal and Uniqueness Matching Priors for Unsupervised Learning of Local Features</a></b>.</p><p><b>WARNING:</b> this demo runs on CPU with downscaled images, without multi-scale or multi-rotation testing, due to limited memory and computational resources; please check out our <a href='https://github.com/naver/pump' target='_blank'>original github repo</a> for these features.</p>"
51
+
52
+ article = "<p style='text-align: center'><a href='https://github.com/naver/pump' target='_blank'>Original Github Repo</a></p>"
53
+
54
+ iface = gr.Interface(
55
+ fn=pump_matching,
56
+ inputs=[
57
+ gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
58
+ gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
59
+ gr.inputs.Checkbox(default=False, label="Use the model trained with style transfer"),
60
+ #gr.inputs.Slider(minimum=300, maximum=600, default=400, step=1, label="Original test scale"),
61
+ #gr.inputs.Slider(minimum=1, maximum=4, default=1, step=0.1, label="Multi Scale Testing - maximum scale (makes it slower)"),
62
+ #gr.inputs.Slider(minimum=0, maximum=180, default=0, step=45, label="Multi Rotation Testing - max rot (makes it slower)"),]
63
+ #+ ([gr.inputs.Checkbox(default=True, label='Use GPU instead of CPU')] if has_cuda else []),
64
+ ],
65
+ outputs=[
66
+ gr.outputs.Image(type="plot", label="Matches in the first image"),
67
+ gr.outputs.Image(type="plot", label="Matches in the second image"),
68
+ ],
69
+ title=title,
70
+ theme='peach',
71
+ description=description,
72
+ article=article,
73
+ examples=[
74
+ ['datasets/demo_warp/mountains_src.jpg','datasets/demo_warp/mountains_tgt.jpg',False],#,400,1,0]+([True] if has_cuda else []),
75
+ ]
76
+ )
77
+ iface.launch(enable_queue=True)
checkpoints/PUMP-stytrf.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e78a9bbbd8a6c9823265adf41b4a330f87fa58fb07832d6d56c6ae94769fd27d
3
+ size 13976029
checkpoints/PUMP.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a58cf5a1a4699e087c269ec9054c35637cd056fc68a37f1ee96da6b53e0804f
3
+ size 13976029
core/conv_mixer.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ """ From the ICLR22 paper: Patches are all you need
12
+ https://openreview.net/pdf?id=TVHS5Y4dNvM
13
+ """
14
+
15
+ class Residual(nn.Module):
16
+ def __init__(self, fn, stride=1):
17
+ super().__init__()
18
+ self.fn = fn
19
+ self.stride = stride
20
+
21
+ def forward(self, x):
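+ # subsample both the identity branch and fn(x) with the same stride so the residual sum keeps matching spatial shapes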
22
+ s = slice(None,None,self.stride)
23
+ return x[:,:,s,s] + self.fn(x)[:,:,s,s]
24
+
25
+
26
+ class ConvMixer (nn.Sequential):
27
+ """ Modified ConvMixer with convolutional layers at the bottom.
28
+
29
+ From the ICLR22 paper: Patches are all you need, https://openreview.net/pdf?id=TVHS5Y4dNvM
30
+ """
31
+ def __init__(self, output_dim, hidden_dim,
32
+ depth=None, kernel_size=5, patch_size=8, group_size=1,
33
+ preconv=1, faster=True, relu=nn.ReLU):
34
+
35
+ assert kernel_size % 2 == 1, 'kernel_size must be odd'
36
+ output_step = 1 + faster
37
+ assert patch_size % output_step == 0, f'patch_size must be multiple of {output_step}'
38
+ self.patch_size = patch_size
39
+
40
+ hidden_dims = [hidden_dim//4]*preconv + [hidden_dim]*(depth+1)
41
+ ops = [
42
+ nn.Conv2d(3, hidden_dims[0], kernel_size=5, padding=2),
43
+ relu(),
44
+ nn.BatchNorm2d(hidden_dims[0])]
45
+
46
+ for _ in range(1,preconv):
47
+ ops += [
48
+ nn.Conv2d(hidden_dims.pop(0), hidden_dims[0], kernel_size=3, padding=1),
49
+ relu(),
50
+ nn.BatchNorm2d(hidden_dims[0])]
51
+
52
+ ops += [
53
+ nn.Conv2d(hidden_dims.pop(0), hidden_dims[0], kernel_size=patch_size, stride=patch_size),
54
+ relu(),
55
+ nn.BatchNorm2d(hidden_dims[0])]
56
+
57
+ for idim, odim in zip(hidden_dims[0:], hidden_dims[1:]):
58
+ ops += [Residual(nn.Sequential(
59
+ nn.Conv2d(idim, idim, kernel_size, groups=max(1,idim//group_size), padding=kernel_size//2),
60
+ relu(),
61
+ nn.BatchNorm2d(idim)
62
+ )),
63
+ nn.Conv2d(idim, odim, kernel_size=1),
64
+ relu(),
65
+ nn.BatchNorm2d(odim)]
66
+ ops += [
67
+ nn.Conv2d(odim, output_dim*(patch_size//output_step)**2, kernel_size=1),
68
+ nn.PixelShuffle( patch_size//output_step ),
69
+ nn.Upsample(scale_factor=output_step, mode='bilinear', align_corners=False)]
70
+
71
+ super().__init__(*ops)
72
+
73
+ def forward(self, img):
74
+ assert img.ndim == 4
75
+ B, C, H, W = img.shape
76
+ desc = super().forward(img)
77
+ return F.normalize(desc, dim=-3)
78
+
79
+
80
+ if __name__ == '__main__':
81
+ net = ConvMixer(128, 512, 7, patch_size=4, kernel_size=9)
82
+ print(net)
83
+
84
+ img = torch.rand(2,3,256,256)
85
+ print('input.shape =', img.shape)
86
+ desc = net(img)
87
+ print('desc.shape =', desc.shape)
core/cuda_deepm/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ *.so
2
+ _ext*
3
+ __pycache__
4
+ build
core/cuda_deepm/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ # run `python setup.py install`
6
+ import cuda_deepm as _kernels
7
+
8
+ __all__ = {k:v for k,v in vars(_kernels).items() if k[0] != '_'}
9
+ globals().update(__all__)
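+ 
+ # Usage sketch (assumes the extension has been built with `python setup.py install`
+ # as noted above; max_pool3d expects a 4-D CUDA tensor, see func.cpp):
+ #   import torch
+ #   from core import cuda_deepm
+ #   maxima, indices = cuda_deepm.max_pool3d(torch.rand(2, 8, 64, 64).cuda(), 3, 2)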
core/cuda_deepm/func.cpp ADDED
@@ -0,0 +1,215 @@
1
+ // Copyright 2022-present NAVER Corp.
2
+ // CC BY-NC-SA 4.0
3
+ // Available only for non-commercial use
4
+
5
+ #include <torch/extension.h>
6
+ using namespace torch::indexing; // Slice
7
+ #include <vector>
8
+
9
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
10
+ #define MAX(x, y) ((x) < (y) ? (y) : (x))
11
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
14
+
15
+ inline Slice sl(bool x) {
16
+ if (x)
17
+ return Slice(0, -1);
18
+ else
19
+ return Slice(1, None);
20
+ }
21
+
22
+ torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
23
+ const at::optional<at::Tensor> weights, torch::Tensor upper );
24
+
25
+ std::vector<torch::Tensor> forward_agg( int level, float norm, const torch::Tensor lower,
26
+ const at::optional<at::Tensor> weights = at::nullopt ) {
27
+ TORCH_CHECK(level >= 1, "level must be >= 1");
28
+ TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
29
+ const auto LH1 = lower.size(0);
30
+ const auto LW1 = lower.size(1);
31
+ const auto LH2 = lower.size(2);
32
+ const auto LW2 = lower.size(3);
33
+ if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
34
+ const auto UH1 = (level == 1) ? LH1+1 : LH1;
35
+ const auto UW1 = (level == 1) ? LW1+1 : LW1;
36
+
37
+ TORCH_CHECK(lower.is_cuda())
38
+ auto upper = torch::zeros({UH1, UW1, LH2, LW2}, lower.options());
39
+ torch::Tensor new_weights = forward_agg_cuda( level, norm, lower, weights, upper );
40
+ return {upper, new_weights};
41
+ }
42
+
43
+
44
+ torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
45
+ const at::optional<at::Tensor> weights, torch::Tensor upper );
46
+
47
+ std::vector<torch::Tensor> forward_pool_agg( int level, float norm, const torch::Tensor lower,
48
+ const at::optional<at::Tensor> weights = at::nullopt ) {
49
+ TORCH_CHECK(level >= 1, "level must be >= 1");
50
+ TORCH_CHECK(lower.dim() == 4, "input must have 4 dimensions");
51
+ const auto LH1 = lower.size(0);
52
+ const auto LW1 = lower.size(1);
53
+ const auto LH2 = lower.size(2);
54
+ const auto LW2 = lower.size(3);
55
+ if (weights) TORCH_CHECK(weights->size(0) == LH1 && weights->size(1) == LW1, "weights should have shape == lower.shape[:2]");
56
+ const auto UH1 = (level == 1) ? LH1+1 : LH1;
57
+ const auto UW1 = (level == 1) ? LW1+1 : LW1;
58
+
59
+ TORCH_CHECK(lower.is_cuda())
60
+ auto upper = torch::zeros({UH1, UW1, 1+(LH2-1)/2, 1+(LW2-1)/2}, lower.options());
61
+ torch::Tensor new_weights = forward_pool_agg_cuda( level, norm, lower, weights, upper );
62
+ return {upper, new_weights};
63
+ }
64
+
65
+ // forward declaration
66
+ void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders );
67
+
68
+ void backward_agg_unpool( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders = true ) {
69
+ TORCH_CHECK(level >= 1, "level must be >= 1");
70
+ TORCH_CHECK( upper.dim() == 4 && lower.dim() == 4, "inputs should be 4-dimensional" );
71
+
72
+ TORCH_CHECK(upper.is_cuda() && lower.is_cuda())
73
+ backward_agg_unpool_cuda(level, upper, lower, exclude_borders);
74
+ }
75
+
76
+
77
+ void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
78
+ torch::Tensor maxima, torch::Tensor indices );
79
+
80
+ std::vector<torch::Tensor> max_pool3d( const torch::Tensor tensor, const int kernel_size, const int stride ) {
81
+ TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
82
+ TORCH_CHECK( 1 <= kernel_size, "bad kernel size %d", kernel_size );
83
+ TORCH_CHECK( 1 <= stride, "bad stride %d", stride );
84
+ const int IB = tensor.size(0);
85
+ const int IH = tensor.size(2); // input height
86
+ const int IW = tensor.size(3); // input width
87
+
88
+ // output size
89
+ const int OH = 1 + (IH - kernel_size) / stride;
90
+ const int OW = 1 + (IW - kernel_size) / stride;
91
+
92
+ torch::Tensor maxima = torch::empty({IB, OH, OW}, tensor.options());
93
+ torch::Tensor indices = torch::empty({IB, OH, OW}, tensor.options().dtype(torch::kInt64));
94
+
95
+ if (tensor.is_cuda())
96
+ max_pool3d_cuda( tensor, kernel_size, stride, maxima, indices );
97
+ else
98
+ TORCH_CHECK(false, "CPU max_pool3d not implemented yet");
99
+ return {maxima, indices};
100
+ }
101
+
102
+ static inline float ptdot( const float* m, float x, float y ) {
103
+ return x*m[0] + y*m[1] + m[2];
104
+ }
105
+
106
+ static inline float pow2(float v) {
107
+ return v*v;
108
+ }
109
+
110
+ void merge_corres_cpu( const torch::Tensor corres, int offset, const torch::Tensor _inv_rot,
111
+ float dmax, torch::Tensor all_corres, const int all_step ) {
112
+ const int H = corres.size(0);
113
+ const int W = corres.size(1);
114
+ const float tol = 2*2; // squared
115
+ dmax *= dmax; // squared
116
+
117
+ TORCH_CHECK( _inv_rot.is_contiguous() );
118
+ const float* inv_rot = _inv_rot.data_ptr<float>();
119
+
120
+ auto corres_a = corres.accessor<float,3>();
121
+ auto all_corres_a = all_corres.accessor<float,3>();
122
+
123
+ // for each bin of the final histograms, we get the nearest-neighbour bin in corres0 and corres1
124
+ for (int j=0; j<all_corres.size(0); j++)
125
+ for (int i=0; i<all_corres.size(1); i++) {
126
+ // printf("accessing all_corres[%d,%d]", j, i);
127
+ auto all_cor = all_corres_a[j][i];
128
+
129
+ // center of the bin in the reference frame
130
+ float x = i*all_step + all_step/2;
131
+ float y = j*all_step + all_step/2;
132
+ // printf(" -> (%g,%g) in ref img", x, y);
133
+
134
+ // center of the bin on the rescaled+rotated image
135
+ float xr = ptdot( inv_rot + 0, x, y );
136
+ float yr = ptdot( inv_rot + 3, x, y );
137
+ // printf(" -> (%g,%g) in rescaled", xr, yr);
138
+
139
+ // iterate on the nearby bins
140
+ int xb = (int)(0.5+ xr/4); // rescaled+rotated desc always has step 4
141
+ int yb = (int)(0.5+ yr/4);
142
+ // printf(" -> (%d,%d) in bins\n", xb, yb);
143
+
144
+ float best = dmax;
145
+ for (int v = MAX(0,yb-1); v <= MIN(H,yb+1); v++)
146
+ for (int u = MAX(0,xb-1); u <= MIN(W,xb+1); u++) {
147
+ // assert( v >= 0 && v < corres_a.size(0) );
148
+ // assert( u >= 0 && u < corres_a.size(1) );
149
+ auto cor = corres_a[v][u];
150
+ float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
151
+ if( d < best ) best = d;
152
+ }
153
+
154
+ for (int v = MAX(0,yb-1); v <= MIN(H,yb+1); v++)
155
+ for (int u = MAX(0,xb-1); u <= MIN(W,xb+1); u++) {
156
+ // assert( v >= 0 && v < corres_a.size(0) );
157
+ // assert( u >= 0 && u < corres_a.size(1) );
158
+ auto cor = corres_a[v][u];
159
+ float d = pow2(cor[offset]-x) + pow2(cor[offset+1]-y);
160
+ if( d <= tol*best ) { // spatially close
161
+ // merge correspondence if score is better than actual
162
+ // printf("update all_corres[%d,%d]\n", v,u);
163
+ if( cor[4] > all_cor[4] )
164
+ for (int k = 0; k < all_corres.size(2); k++)
165
+ all_cor[k] = cor[k];
166
+ }
167
+ }
168
+ }
169
+ }
170
+
171
+ void merge_corres_cuda( const torch::Tensor corres, int offset, const torch::Tensor inv_rot,
172
+ float dmax, torch::Tensor all_corres, const int all_step );
173
+
174
+ void merge_corres( const torch::Tensor corres, int offset, const torch::Tensor rot,
175
+ torch::Tensor all_corres, const int all_step ) {
176
+ TORCH_CHECK( corres.dim() == 3 && corres.size(2) == 6, "corres.shape should be (H,W,6)" );
177
+ TORCH_CHECK( all_corres.dim() == 3 && all_corres.size(2) == 6, "all_corres.shape should be (H,W,6)" );
178
+
179
+ float dmax = 8 * torch::sqrt(torch::det(rot)).item<float>();
180
+ torch::Tensor inv_rot = torch::inverse(rot).contiguous();
181
+
182
+ if (all_corres.is_cuda())
183
+ merge_corres_cuda( corres, offset, inv_rot, dmax, all_corres, all_step );
184
+ else
185
+ merge_corres_cpu( corres, offset, inv_rot, dmax, all_corres, all_step );
186
+ }
187
+
188
+
189
+ void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
190
+ const float radius, const float alpha);
191
+
192
+ void mask_correlations_radial( torch::Tensor corr, const torch::Tensor targets,
193
+ const float radius, const float alpha) {
194
+ // radius: protected area in pixels around each target center
195
+ // alpha: in [0,1]. If alpha = 0: no effect. If alpha = 1: full effect.
196
+ TORCH_CHECK( corr.dim() == 4 );
197
+ TORCH_CHECK( targets.dim() == 3 );
198
+ TORCH_CHECK( targets.size(0) == corr.size(0) && targets.size(1) == corr.size(1) && targets.size(2) == 2,
199
+ "correlations and targets should have the same shape[:2]" );
200
+
201
+ if (corr.is_cuda())
202
+ mask_correlations_radial_cuda( corr, targets, radius, alpha );
203
+ else
204
+ TORCH_CHECK(false, "TODO");
205
+ }
206
+
207
+
208
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
209
+ m.def("forward_agg", &forward_agg, "forward aggregation (CUDA)");
210
+ m.def("forward_pool_agg", &forward_pool_agg, "forward pooling and aggregation (CUDA)");
211
+ m.def("backward_agg_unpool", &backward_agg_unpool, "backward sparse-conv and max-unpooling (C++ & CUDA)");
212
+ m.def("max_pool3d", &max_pool3d, "max_pool3d that can handle big inputs (CUDA)");
213
+ m.def("merge_corres_one_side", &merge_corres, "merge correspondences on CPU or GPU" );
214
+ m.def("mask_correlations_radial", &mask_correlations_radial, "mask correlations radially (CUDA)" );
215
+ }
core/cuda_deepm/kernels.cu ADDED
@@ -0,0 +1,578 @@
1
+ // Copyright 2022-present NAVER Corp.
2
+ // CC BY-NC-SA 4.0
3
+ // Available only for non-commercial use
4
+
5
+ #include <torch/extension.h>
6
+ #include <cuda.h>
7
+ #include <cuda_runtime.h>
8
+ #include <vector>
9
+
10
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
11
+ #define MAX(x, y) ((x) < (y) ? (y) : (x))
12
+ #define inf std::numeric_limits<float>::infinity()
13
+
14
+ #define CHECK_CUDA(tensor) {\
15
+ TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
16
+ TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
17
+ void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
18
+
19
+
20
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 600
21
+ #define atomicMax_block atomicMax
22
+ #endif
23
+
24
+
25
+ template <typename scalar_t>
26
+ __global__ void forward_agg_cuda_kernel(
27
+ const int LH1, const int LW1, const int LH2, const int LW2,
28
+ const int gap_left, const int gap_right, float norm,
29
+ const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> lower,
30
+ torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> upper,
31
+ const float* weights, float* new_weights ) {
32
+
33
+ const auto UH1 = LH1 + bool(!gap_left); // level 0 is smaller than other levels
34
+ const auto UW1 = LW1 + bool(!gap_left);
35
+ const auto UH2 = LH2;
36
+ const auto UW2 = LW2;
37
+
38
+ int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
39
+ const int uw2 = idx % UW2; idx /= UW2;
40
+ const int uh2 = idx % UH2; idx /= UH2;
41
+ const int uw1 = idx % UW1; idx /= UW1;
42
+ const int uh1 = idx;
43
+ if (uh1 >= UH1) return;
44
+
45
+ // then, add the 4 child
46
+ float sumw = 0, nrm = 0, res = 0;
47
+ // #pragma unroll
48
+ for (int i = 0; i < 4; i++) {
49
+ const int v = i/2, u = i%2;
50
+ // source pixel
51
+ const int lh1 = uh1 + (1-v) * gap_left - v * gap_right;
52
+ if (lh1 < 0 || lh1 >= LH1) continue;
53
+ const int lw1 = uw1 + (1-u) * gap_left - u * gap_right;
54
+ if (lw1 < 0 || lw1 >= LW1) continue;
55
+
56
+ // load weight even if (lh2,lw2) are invalid
57
+ const float weight = weights ? weights[lh1*LW1 + lw1] : 1;
58
+ sumw += weight;
59
+
60
+ const int lh2 = uh2 + 1 - 2*v;
61
+ if (lh2 < 0 || lh2 >= LH2) continue;
62
+ const int lw2 = uw2 + 1 - 2*u;
63
+ if (lw2 < 0 || lw2 >= LW2) continue;
64
+
65
+ res += weight * lower[lh1][lw1][lh2][lw2];
66
+ nrm += weight;
67
+ }
68
+
69
+ // normalize output
70
+ nrm = sumw * (nrm < sumw ? powf(nrm/sumw, norm) : 1);
71
+ upper[uh1][uw1][uh2][uw2] = (nrm ? res / nrm : 0);
72
+ if (uh2 == 1 && uw2 == 1)
73
+ new_weights[uh1*UW1 + uw1] = sumw;
74
+ }
75
+
76
+ torch::Tensor forward_agg_cuda( int level, float norm, const torch::Tensor lower,
77
+ const at::optional<at::Tensor> weights, torch::Tensor upper ) {
78
+ CHECK_CUDA(lower);
79
+ CHECK_CUDA(upper);
80
+ if (weights) CHECK_CUDA(weights.value());
81
+
82
+ const auto UH1 = upper.size(0);
83
+ const auto UW1 = upper.size(1);
84
+ const auto UH2 = upper.size(2);
85
+ const auto UW2 = upper.size(3);
86
+ const auto LH1 = lower.size(0);
87
+ const auto LW1 = lower.size(1);
88
+ const auto LH2 = lower.size(2);
89
+ const auto LW2 = lower.size(3);
90
+ TORCH_CHECK( UH1 == LH1 + int(level==1) && UW1 == LW1 + int(level==1), "inconsistent lower and upper shapes" );
91
+
92
+ const int gap_left = (level >= 2) ? 1 << (level-2) : 0; // 0, 1, 2, 4, ...
93
+ const int gap_right= 1 << MAX(0, level-2); // 1, 1, 2, 4, ...
94
+
95
+ const int MAX_THREADS = 512; // faster than 1024 (higher SM occupancy)
96
+ const int THREADS_PER_BLOCK = MAX_THREADS;
97
+ const int N_BLOCKS = (UH1*UW1*UH2*UW2 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
98
+
99
+ torch::Tensor new_weights = torch::zeros({UH1, UW1}, upper.options().dtype(torch::kFloat32));
100
+
101
+ // one block for each layer, one thread per local-max
102
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(lower.type(), "forward_agg_cuda", ([&] {
103
+ forward_agg_cuda_kernel<<<N_BLOCKS, THREADS_PER_BLOCK>>>(
104
+ LH1, LW1, LH2, LW2,
105
+ gap_left, gap_right, norm,
106
+ lower.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
107
+ upper.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
108
+ weights ? weights->data_ptr<float>() : nullptr, new_weights.data_ptr<float>() );
109
+ }));
110
+ return new_weights;
111
+ }
112
+
113
+ template <typename scalar_t>
114
+ __global__ void forward_pool_agg_cuda_kernel(
115
+ const int LH1, const int LW1, const int LH2, const int LW2,
116
+ // const int UH1, const int UW1, const int UH2, const int UW2,
117
+ const int gap_left, const int gap_right, float norm,
118
+ const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> lower,
119
+ torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> upper,
120
+ const float* weights, float* new_weights ) {
121
+
122
+ const auto UH1 = LH1 + bool(!gap_left); // level 0 is smaller than other levels
123
+ const auto UW1 = LW1 + bool(!gap_left);
124
+ const auto UH2 = (LH2-1)/2 + 1;
125
+ const auto UW2 = (LW2-1)/2 + 1;
126
+
127
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
128
+ const int uw2 = idx % UW2; idx /= UW2;
129
+ const int uh2 = idx % UH2; idx /= UH2;
130
+ const int uw1 = idx % UW1; idx /= UW1;
131
+ const int uh1 = idx;
132
+ if (uh1 >= UH1) return;
133
+
134
+ // then, add the 4 children
135
+ float sumw = 0, nrm = 0, res = 0;
136
+ // #pragma unroll
137
+ for (int i = 0; i < 4; i++) {
138
+ const int v = i/2, u = i%2;
139
+ // source pixel
140
+ const int lh1 = uh1 + (1-v) * gap_left - v * gap_right;
141
+ if (lh1 < 0 || lh1 >= LH1) continue;
142
+ const int lw1 = uw1 + (1-u) * gap_left - u * gap_right;
143
+ if (lw1 < 0 || lw1 >= LW1) continue;
144
+
145
+ // load weight even if (lh2,lw2) are invalid
146
+ const float weight = weights ? weights[lh1*LW1 + lw1] : 1;
147
+ sumw += weight;
148
+
149
+ const int lh2_ = 2*(uh2 + 1 - 2*v); // position in lower
150
+ const int lw2_ = 2*(uw2 + 1 - 2*u);
151
+ float lower_max = -inf;
152
+ #pragma unroll
153
+ for (int j = -1; j <= 1; j++) {
154
+ const int lh2 = lh2_ + j;
155
+ if (lh2 < 0 || lh2 >= LH2) continue;
156
+ #pragma unroll
157
+ for (int i = -1; i <= 1; i++) {
158
+ const int lw2 = lw2_ + i;
159
+ if (lw2 < 0 || lw2 >= LW2) continue;
160
+ float l = lower[lh1][lw1][lh2][lw2];
161
+ lower_max = MAX(lower_max, l);
162
+ }}
163
+ if (lower_max == -inf) continue;
164
+
165
+ res += weight * lower_max;
166
+ nrm += weight;
167
+ }
168
+
169
+ // normalize output
170
+ nrm = sumw * (nrm < sumw ? powf(nrm/sumw, norm) : 1);
171
+ upper[uh1][uw1][uh2][uw2] = (nrm ? res / nrm : 0);
172
+ if (uh2 == 1 && uw2 == 1)
173
+ new_weights[uh1*UW1 + uw1] = sumw;
174
+ }
175
+
176
+ torch::Tensor forward_pool_agg_cuda( int level, float norm, const torch::Tensor lower,
177
+ const at::optional<at::Tensor> weights, torch::Tensor upper ) {
178
+ CHECK_CUDA(lower);
179
+ CHECK_CUDA(upper);
180
+ if (weights) CHECK_CUDA(weights.value());
181
+
182
+ const auto LH1 = lower.size(0);
183
+ const auto LW1 = lower.size(1);
184
+ const auto LH2 = lower.size(2);
185
+ const auto LW2 = lower.size(3);
186
+ const auto UH1 = upper.size(0);
187
+ const auto UW1 = upper.size(1);
188
+ const auto UH2 = upper.size(2);
189
+ const auto UW2 = upper.size(3);
190
+ TORCH_CHECK( UH1 == LH1 + int(level==1) && UW1 == LW1 + int(level==1), "inconsistent lower and upper shapes" );
191
+ TORCH_CHECK( UH2 == (LH2-1)/2+1 && UW2 == (LW2-1)/2+1, "lower level should be twice as big" );
192
+
193
+ const int gap_left = (level >= 2) ? 1 << (level-2) : 0; // 0, 1, 2, 4, ...
194
+ const int gap_right= 1 << MAX(0, level-2); // 1, 1, 2, 4, ...
195
+
196
+ const int MAX_THREADS = 512; // faster than 1024 (higher SM occupancy)
197
+ const int THREADS_PER_BLOCK = MAX_THREADS;
198
+ const int N_BLOCKS = (UH1*UW1*UH2*UW2 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
199
+
200
+ torch::Tensor new_weights = torch::zeros({UH1, UW1}, upper.options().dtype(torch::kFloat));
201
+
202
+ // one thread per output cell of the upper correlation volume
203
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(lower.type(), "forward_pool_agg_cuda", ([&] {
204
+ forward_pool_agg_cuda_kernel<<<N_BLOCKS, THREADS_PER_BLOCK>>>(
205
+ LH1, LW1, LH2, LW2,
206
+ // UH1, UW1, UH2, UW2,
207
+ gap_left, gap_right, norm,
208
+ lower.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
209
+ upper.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
210
+ weights ? weights->data_ptr<float>() : nullptr, new_weights.data_ptr<float>() );
211
+ }));
212
+ return new_weights;
213
+ }
214
+
215
+ __device__ inline int in(int lower, int var, int upper) {
216
+ return lower <= var && var < upper;
217
+ }
218
+ __device__ inline int sl(bool b) {
219
+ return b ? 1 : -1;
220
+ }
221
+
222
+ __device__ short atomicMaxShort(short* address, short val) {
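+ // CUDA has no native 16-bit atomicMax, so it is emulated with a CAS loop on the aligned
+ // 32-bit word containing `address`: the relevant half-word is extracted with __byte_perm,
+ // max'ed with `val`, spliced back, and atomicCAS retries until no other thread interfered.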
223
+ unsigned int *base_address = (unsigned int *)((size_t)address & ~3); // multiple of 4
224
+
225
+ unsigned int order_from[] = {0x0010, 0x0032}; // either bytes[0:2] or bytes[2:4]
226
+ unsigned int from = order_from[((size_t)address & 3) / 2];
227
+
228
+ unsigned int order_back[] = {0x3254, 0x5410}; // right-to-left
229
+ unsigned int back = order_back[((size_t)address & 3) / 2];
230
+ unsigned int old, assumed, max_, new_;
231
+
232
+ old = *base_address;
233
+ do {
234
+ assumed = old;
235
+ max_ = max(val, (short)__byte_perm(old, 0, from)); // extract word
236
+ new_ = __byte_perm(old, max_, back); // replace word
237
+ old = atomicCAS(base_address, assumed, new_);
238
+ } while (assumed != old);
239
+ return old;
240
+ }
241
+
242
+ template <typename scalar_t>
243
+ __device__ inline void TplAtomicMax_block( scalar_t* before, scalar_t after ) { assert(!"atomicMax not implemented for this dtype"); }
244
+ template <>
245
+ __device__ inline void TplAtomicMax_block( at::Half* before, at::Half after ) { atomicMaxShort( (int16_t*)before, *(int16_t*)&after ); }
246
+ template <>
247
+ __device__ inline void TplAtomicMax_block( float* before, float after ) { atomicMax_block( (int32_t*)before, *(int32_t*)&after ); }
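+ // note: comparing float bit patterns as signed integers preserves ordering only for
+ // non-negative values, which the aggregated scores are assumed to be here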
248
+
249
+ template <typename scalar_t>
250
+ __global__ void backward_agg_unpool_cuda_kernel(
251
+ const int UH1, const int UW1,
252
+ const int UH2, const int UW2,
253
+ const int LH2, const int LW2,
254
+ const int gap_left, const int gap_right,
255
+ const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> upper,
256
+ torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> lower ) {
257
+
258
+ /* Each block is going to take care of a single layer, i.e. lower[:,:,0::2,0::2].
259
+ the first thread is allocating some global memory and then frees it later.
260
+ */
261
+ // const int LH1 = gridDim.x;
262
+ // const int LW1 = gridDim.y;
263
+ const int lh1 = blockIdx.y;
264
+ const int lw1 = blockIdx.x;
265
+ const int UHW2 = UH2 * UW2; // upper layer size
266
+
267
+ __shared__ float* _shared_addr;
268
+ if (threadIdx.x == 0)
269
+ do{ _shared_addr = new float [2*UHW2]; } // for each upper place, we have (best, bestp)
270
+ while(!_shared_addr); // waiting for memory to be available...
271
+ __syncthreads();
272
+
273
+ float * layer_best = _shared_addr;
274
+ int * layer_bestp = (int*)(_shared_addr+1); // interleaved with layer_best: (best, bestp) pairs
275
+ assert( layer_best );
276
+
277
+ /* First pass: we recover the position and values of all local maxima in the layer
278
+ */
279
+ for (int idx = threadIdx.x; idx < UHW2; idx += blockDim.x) {
280
+ const int ux = idx % UW2;
281
+ const int uy = idx / UW2;
282
+ const int lx = 2*ux; // lower pos from upper pos
283
+ const int ly = 2*uy;
284
+
285
+ // find the local maximum in the 3x3 neighborhood
286
+ float best = -inf;
287
+ int bestp = 0;
288
+ #pragma unroll
289
+ for (int j_= -1; j_<= 1; j_++) {
290
+ const int j = ly + j_;
291
+ if (j < 0 || j >= LH2) continue;
292
+ #pragma unroll
293
+ for (int i_= -1; i_<= 1; i_++) {
294
+ const int i = lx + i_;
295
+ if (i < 0 || i >= LW2) continue;
296
+ float cur = lower[lh1][lw1][j][i];
297
+ if (cur > best) { best = cur; bestp = j*LW2+i; }
298
+ }}
299
+ layer_best[2*idx] = best;
300
+ layer_bestp[2*idx] = bestp;
301
+ }
302
+
303
+ __syncthreads();
304
+
305
+ /* Second pass: we update the local maxima according to the upper layer
306
+ */
307
+ for (int idx = threadIdx.x; idx < UHW2; idx += blockDim.x) {
308
+ const int ux = idx % UW2;
309
+ const int uy = idx / UW2;
310
+
311
+ // max-pool the additional value from the upper layer
312
+ scalar_t add = 0;
313
+ for (int v = -gap_left; v <= gap_right; v += gap_right+gap_left) {
314
+ for (int u = -gap_left; u <= gap_right; u += gap_right+gap_left) {
315
+ const int uh1 = lh1 + v, uw1 = lw1 + u;
316
+ const int uh2 = uy+sl(v>0), uw2 = ux+sl(u>0);
317
+ if (in(0, uh1, UH1) && in(0, uw1, UW1) && in(0, uh2, UH2) && in(0, uw2, UW2))
318
+ add = MAX(add, upper[uh1][uw1][uh2][uw2]);
319
+ }}
320
+
321
+ // grab local maxima
322
+ float best = layer_best[2*idx];
323
+ int bestp = layer_bestp[2*idx];
324
+ const int lx = bestp % LW2;
325
+ const int ly = bestp / LW2;
326
+
327
+ // printf("UH=%d,UW=%d: uy=%d,ux=%d --> best=%g at ly=%d,lx=%d\n", UH,UW, uy,ux, best, ly,lx);
328
+ scalar_t* before = & lower[lh1][lw1][ly][lx];
329
+ scalar_t after = best + add;
330
+ TplAtomicMax_block<scalar_t>( before, after );
331
+ }
332
+
333
+ __syncthreads();
334
+
335
+ if (threadIdx.x == 0)
336
+ delete[] _shared_addr; // allocated with new[]
337
+ }
338
+
339
+ void backward_agg_unpool_cuda( int level, const torch::Tensor upper, torch::Tensor lower, bool exclude_borders ) {
340
+ CHECK_CUDA(lower);
341
+ CHECK_CUDA(upper);
342
+
343
+ const auto UH1 = upper.size(0);
344
+ const auto UW1 = upper.size(1);
345
+ const auto UH2 = upper.size(2);
346
+ const auto UW2 = upper.size(3);
347
+ const auto LH1 = lower.size(0);
348
+ const auto LW1 = lower.size(1);
349
+ const auto LH2 = lower.size(2);
350
+ const auto LW2 = lower.size(3);
351
+ TORCH_CHECK( UH1 == LH1 + int(level==1) && UW1 == LW1 + int(level==1), "inconsistent lower and upper shapes" );
352
+ const int xb = exclude_borders; // local_argmax cannot reach the bottom and right borders
353
+
354
+ const int gap_left = (level >= 2) ? 1 << (level-2) : 0; // 0, 1, 2, 4, ...
355
+ const int gap_right= 1 << MAX(0, level-2); // 1, 1, 2, 4, ...
356
+
357
+ const int64_t MAX_THREADS = 1024;
358
+ const int64_t THREADS_PER_LAYER = MIN(UH2*UW2, MAX_THREADS);
359
+
360
+ // one block for each layer, one thread per local-max
361
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(upper.type(), "backward_agg_unpool_cuda", ([&] {
362
+ backward_agg_unpool_cuda_kernel<<<dim3(LW1,LH1), THREADS_PER_LAYER>>>(
363
+ UH1, UW1, UH2, UW2, LH2-xb, LW2-xb,
364
+ gap_left, gap_right,
365
+ upper.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
366
+ lower.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>());
367
+ }));
368
+ CHECK_KERNEL();
369
+ }
370
+
371
+ template <typename scalar_t>
372
+ __global__ void max_pool3d_cuda_kernel(
373
+ const int BS, const int NC, const int IH, const int IW, const int OH, const int OW,
374
+ const int ks, const int stride,
375
+ const torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> tensor,
376
+ torch::PackedTensorAccessor64<scalar_t,3,torch::RestrictPtrTraits> maxima,
377
+ torch::PackedTensorAccessor64<int64_t, 3,torch::RestrictPtrTraits> indices ) {
378
+
379
+ // each thread takes care of one output
380
+ int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
381
+ const int x = idx % OW; idx /= OW;
382
+ const int y = idx % OH; idx /= OH;
383
+ const int b = idx;
384
+ if (b >= BS) return;
385
+
386
+ float best = -inf;
387
+ int64_t best_pos = 0;
388
+ for (int64_t c = 0; c < NC; c++) {
389
+ for (int j = stride*y; j < stride*y+ks; j++) {
390
+ for (int i = stride*x; i < stride*x+ks; i++) {
391
+ // assert( b < BS and c < NC and j < IH and i < IW );
392
+ float cur = tensor[b][c][j][i];
393
+ if (cur > best) {best = cur; best_pos = (c*IH + j)*IW+ i; }
394
+ }}}
395
+
396
+ // assert( b < BS and y < OH and x < OW );
397
+ maxima [b][y][x] = best;
398
+ indices[b][y][x] = best_pos;
399
+ }
400
+
401
+ void max_pool3d_cuda( const torch::Tensor tensor, const int kernel_size, const int stride,
402
+ torch::Tensor maxima, torch::Tensor indices ) {
403
+ CHECK_CUDA(tensor);
404
+ TORCH_CHECK(tensor.dim() == 4, "tensor should be 4-dimensional: BxCxHxW");
405
+ const int BS = tensor.size(0);
406
+ const int NC = tensor.size(1);
407
+ const int IH = tensor.size(2); // input height
408
+ const int IW = tensor.size(3); // input width
409
+
410
+ // output size
411
+ TORCH_CHECK( maxima.sizes() == indices.sizes(), "maxima and indices should have the same shape" );
412
+ TORCH_CHECK( BS == maxima.size(0), "bad batch size" );
413
+ const int OH = maxima.size(1);
414
+ const int OW = maxima.size(2);
415
+
416
+ const int64_t THREADS_PER_LAYER = 512;
417
+ const int64_t N_BLOCKS = (BS*OH*OW + THREADS_PER_LAYER-1) / THREADS_PER_LAYER;
418
+
419
+ // one thread per output element
420
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor.type(), "max_pool3d_cuda", ([&] {
421
+ max_pool3d_cuda_kernel<<<N_BLOCKS, THREADS_PER_LAYER>>>(
422
+ BS, NC, IH, IW, OH, OW, kernel_size, stride,
423
+ tensor. packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
424
+ maxima. packed_accessor64<scalar_t,3,torch::RestrictPtrTraits>(),
425
+ indices.packed_accessor64<int64_t,3,torch::RestrictPtrTraits>());
426
+ }));
427
+ }
428
+
429
+
430
+ __device__ inline float ptdot( const float* m, float x, float y ) {
431
+ return x*m[0] + y*m[1] + m[2];
432
+ }
433
+
434
+ __device__ inline float sqr(float v) {
435
+ return v*v;
436
+ }
437
+
438
+
439
+ __global__ void merge_corres_cuda_kernel(
440
+ const int OH, const int OW, const int OZ, const int IH, const int IW,
441
+ const float dmax2, int offset, const float* inv_rot, const int all_step,
442
+ const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> corres_a,
443
+ torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> all_corres_a ) {
444
+
445
+ // each thread takes care of one output
446
+ int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
447
+ const int i = idx % OW; idx /= OW;
448
+ const int j = idx;
449
+ if (j >= OH) return;
450
+
451
+ const float tol2 = 2*2; // squared
452
+ auto all_cor = all_corres_a[j][i];
453
+
454
+ // center of the bin in the reference frame
455
+ float x = i*all_step + all_step/2;
456
+ float y = j*all_step + all_step/2;
457
+
458
+ // center of the bin on the rescaled+rotated image
459
+ float xr = ptdot( inv_rot + 0, x, y );
460
+ float yr = ptdot( inv_rot + 3, x, y );
461
+
462
+ // iterate on the nearby bins
463
+ int xb = (int)(0.5+ xr/4); // rescaled+rotated desc always has step 4
464
+ int yb = (int)(0.5+ yr/4);
465
+
466
+ float best = dmax2;
467
+ #pragma unroll
468
+ for (int _v = -1; _v <= 1; _v++) {
469
+ #pragma unroll
470
+ for (int _u = -1; _u <= 1; _u++) {
471
+ const int v = yb+_v, u = xb+_u;
472
+ if (!(in(0, v, IH) && in(0, u, IW))) continue;
473
+ auto cor = corres_a[v][u];
474
+ float d = sqr(cor[offset]-x) + sqr(cor[offset+1]-y);
475
+ if (d < best) best = d;
476
+ }}
477
+
478
+ #pragma unroll
479
+ for (int _v = -1; _v <= 1; _v++) {
480
+ #pragma unroll
481
+ for (int _u = -1; _u <= 1; _u++) {
482
+ const int v = yb+_v, u = xb+_u;
483
+ if (!(in(0, v, IH) && in(0, u, IW))) continue;
484
+ auto cor = corres_a[v][u];
485
+ float d = sqr(cor[offset]-x) + sqr(cor[offset+1]-y);
486
+ if (d <= tol2*best) { // spatially close
487
+ // merge correspondence if score is better than actual
488
+ if (cor[4] > all_cor[4])
489
+ for (int k = 0; k < OZ; k++) all_cor[k] = cor[k];
490
+ }
491
+ }}
492
+ }
493
+
494
+ void merge_corres_cuda( const torch::Tensor corres, const int offset, const torch::Tensor _inv_rot,
495
+ const float dmax, torch::Tensor all_corres, const int all_step ) {
496
+ CHECK_CUDA( corres );
497
+ CHECK_CUDA( all_corres );
498
+ CHECK_CUDA( _inv_rot );
499
+ TORCH_CHECK(_inv_rot.is_contiguous(), "inv_rot should be contiguous" );
500
+
501
+ const int IH = corres.size(0);
502
+ const int IW = corres.size(1);
503
+ const int IZ = corres.size(2);
504
+ const int OH = all_corres.size(0);
505
+ const int OW = all_corres.size(1);
506
+ const int OZ = all_corres.size(2);
507
+ TORCH_CHECK( IZ == OZ, "corres and all_corres should have the same shape[2]" );
508
+
509
+ const int THREADS_PER_LAYER = 512;
510
+ const int N_BLOCKS = (OH * OW + THREADS_PER_LAYER-1) / THREADS_PER_LAYER;
511
+
512
+ merge_corres_cuda_kernel<<<N_BLOCKS, THREADS_PER_LAYER>>>(
513
+ OH, OW, OZ, IH, IW, dmax*dmax, offset, _inv_rot.data_ptr<float>(), all_step,
514
+ corres.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
515
+ all_corres.packed_accessor32<float,3,torch::RestrictPtrTraits>());
516
+ CHECK_KERNEL();
517
+ }
518
+
519
+
520
+ template <typename scalar_t>
521
+ __global__ void mask_correlations_radial_cuda_kernel(
522
+ float radius, const float alpha,
523
+ const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> targets,
524
+ torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> corr ) {
525
+
526
+ #define H1 ((int)corr.size(0))
527
+ #define W1 ((int)corr.size(1))
528
+ #define H2 ((int)corr.size(2))
529
+ #define W2 ((int)corr.size(3))
530
+
531
+ // each block takes care of one layer corr[j,i,:,:]
532
+ const int j = blockIdx.x / W1;
533
+ const int i = blockIdx.x % W1;
534
+ if (j >= H1) return;
535
+
536
+ // read the target center
537
+ const float cx = targets[j][i][0];
538
+ const float cy = targets[j][i][1];
539
+ if (cx != cx || cy != cy) return; // undefined center
540
+ radius *= radius; // squared
541
+ const float alpha_out = (alpha > 1 ? 1 : alpha);
542
+ const float alpha_in = (alpha < 1 ? 1 : alpha);
543
+
544
+ for (int idx = threadIdx.x; idx < H2*W2; idx += blockDim.x) {
545
+ const int v = idx / W2;
546
+ const int u = idx % W2;
547
+
548
+ // compute weighting
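+ // cells inside `radius` are scaled by alpha_in; outside, the factor decays smoothly
+ // from 1 at the radius towards (1 - alpha_out) as the squared distance grows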
549
+ float dis2 = sqr(u - cx) + sqr(v - cy);
550
+ float mul = alpha_in;
551
+ if (dis2 > radius)
552
+ mul = 1 - alpha_out*(1 - radius / dis2);
553
+
554
+ corr[j][i][v][u] *= mul;
555
+ }
556
+ }
557
+
558
+ void mask_correlations_radial_cuda( torch::Tensor corr, const torch::Tensor targets,
559
+ const float radius, const float alpha) {
560
+ CHECK_CUDA( corr );
561
+ CHECK_CUDA( targets );
562
+
563
+ const int THREADS_PER_LAYER = 512;
564
+ const int N_BLOCKS = H1*W1;
565
+
566
+ #undef H1
567
+ #undef W1
568
+ #undef H2
569
+ #undef W2
570
+
571
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(corr.type(), "mask_correlations_radial_cuda", ([&] {
572
+ mask_correlations_radial_cuda_kernel<<<N_BLOCKS, THREADS_PER_LAYER>>>(
573
+ radius, alpha,
574
+ targets.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
575
+ corr.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>());
576
+ }));
577
+ CHECK_KERNEL();
578
+ }
core/cuda_deepm/setup.py ADDED
@@ -0,0 +1,24 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from setuptools import setup
6
+ from torch import cuda
7
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8
+
9
+ # if you want to compile for all possible CUDA architectures
10
+ all_cuda_archs = [] #cuda.get_gencode_flags().replace('compute=','arch=').split()
11
+
12
+ setup(
13
+ name='cuda_deepm',
14
+ ext_modules = [
15
+ CUDAExtension(
16
+ name = 'cuda_deepm',
17
+ sources = ["func.cpp", "kernels.cu"],
18
+ extra_compile_args = dict(nvcc=['-O2']+all_cuda_archs, cxx=['-O2'])
19
+ )
20
+ ],
21
+ cmdclass = {
22
+ 'build_ext': BuildExtension
23
+ })
24
+
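+ # Build note (sketch, not enforced by this script): the extension is typically compiled from
+ # this directory with `python setup.py build_ext --inplace` or installed with `pip install .`,
+ # using a CUDA toolkit compatible with the installed PyTorch build.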
core/functional.py ADDED
@@ -0,0 +1,440 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def affmul( aff, vecs ):
12
+ """ affine multiplication:
13
+ computes aff @ vecs.T """
14
+ if aff is None: return vecs
15
+ if isinstance(aff, (tuple,list)) or aff.ndim==3:
16
+ assert len(aff) == 2
17
+ assert 4 <= vecs.shape[-1], bb()
18
+ vecs = vecs.clone() if isinstance(vecs, torch.Tensor) else vecs.copy()
19
+ vecs[...,0:2] = affmul(aff[0], vecs[...,0:2])
20
+ vecs[...,2:4] = affmul(aff[1], vecs[...,2:4])
21
+ return vecs
22
+ else:
23
+ assert vecs.shape[-1] == 2, bb()
24
+ assert aff.shape == (2,3) or (aff.shape==(3,3) and
25
+ aff[2,0] == aff[2,1] == 0 and aff[2,2] == 1), bb()
26
+ return (vecs @ aff[:2,:2].T) + aff[:2,2]
27
+
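+ # Example sketch (illustrative values): with a 2x3 affine A = [[1, 0, 10], [0, 1, 5]],
+ # affmul(torch.tensor(A, dtype=torch.float32), torch.tensor([[0., 0.], [2., 3.]]))
+ # returns [[10., 5.], [12., 8.]], i.e. each (x, y) row is mapped through A.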
28
+
29
+ def imresize( img, max_size, mode='area' ):
30
+ # trf: cur_pix --> old_pix
31
+ img, trf = img if isinstance(img,tuple) else (img, torch.eye(3,device=img.device))
32
+
33
+ shape = img.shape[-2:]
34
+ if max_size > 0 and max(shape) > max_size:
35
+ new_shape = tuple(i * max_size // max(shape) for i in shape)
36
+ img = F.interpolate( img[None].float(), size=new_shape, mode=mode )[0]
37
+ img.clamp_(min=0, max=255)
38
+ sca = torch.diag(torch.tensor((shape[0]/new_shape[0],shape[1]/new_shape[1],1), device=img.device))
39
+ img = img.byte()
40
+ trf = trf @ sca # undo sca first
41
+
42
+ return img, trf
43
+
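+ # Shape sketch (illustrative values): a 3x480x640 uint8 image with max_size=320 comes back
+ # as 3x240x320, and trf becomes diag(2, 2, 1), mapping a resized pixel back to the original one.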
44
+
45
+ def rotate_img( img, angle, crop=False ):
46
+ if angle in (0, 90, 180, 270):
47
+ return rotate_img_90(img,angle)
48
+
49
+ img, trf = img
50
+ assert trf.shape == (3,3)
51
+
52
+ def centered_rotation(rotation, shape, **device):
53
+ # rotation matrix
54
+ # pt_in_original_image = rot * pt_in_rotated_image
55
+ angle = rotation * np.pi / 180
56
+ c, s = np.cos(angle), np.sin(angle)
57
+ rot = torch.tensor([(c, -s, 0), (s, c, 0), (0, 0, 1)], dtype=torch.float32, **device)
58
+
59
+ # determine center of rotation before
60
+ H, W = shape
61
+ c_before = torch.tensor((W,H), **device) / 2
62
+ if crop:
63
+ c_after = c_before
64
+ rot_size = (W,H)
65
+ else:
66
+ # enlarge image to fit everything
67
+ corners = torch.tensor([(0, W, W, 0), (0, 0, H, H)], dtype=torch.float32, **device)
68
+ corners = affmul(rot, corners.T).T
69
+ rot_size = (corners.max(dim=1).values - corners.min(dim=1).values + 0.5).int()
70
+ rot_size = (rot_size // 4) * 4 # legacy
71
+ c_after = rot_size / 2
72
+
73
+ rot[:2,2] = c_before - affmul(rot, c_after) # fix translation
74
+ return rot, tuple(rot_size)[::-1]
75
+
76
+ C, H, W = img.shape
77
+ rot, (OH, OW) = centered_rotation(angle, (H,W), device=img.device)
78
+
79
+ # pt_in_original_image = rot * pt_in_rotated_image
80
+ # but pytorch works in [-1,1] coordinates... annoying
81
+ # pt_in_original_1_1 = orig_px_to_1_1 * rot * rotated_1_1_to_px * pt_in_rotated_1_1
82
+ _1_1_to_px = lambda W,H: torch.tensor(((W/2, 0, W/2), (0, H/2, H/2), (0, 0, 1)), device=img.device)
83
+ theta = torch.inverse(_1_1_to_px(W-1,H-1)) @ rot @ _1_1_to_px(OW-1,OH-1)
84
+
85
+ grid = F.affine_grid(theta[None,:2], (1, C, OH, OW), align_corners=True)
86
+ res = F.grid_sample(img[None].float(), grid, align_corners=True).to(dtype=img.dtype)[0]
87
+ return res, trf @ rot
88
+
89
+
90
+
91
+ def rotate_img_90( img, angle ):
92
+ """ Rotate an image by a multiple of 90 degrees using simple transpose and flip ops.
93
+ img = tuple( image, existing_trf )
94
+ existing_trf: current --> old
95
+ """
96
+ angle = angle % 360
97
+ assert angle in (0, 90, 180, 270), 'cannot handle rotation other than multiple of 90 degrees'
98
+ img, trf = img
99
+ assert trf.shape == (3,3)
100
+
101
+ if isinstance(img, np.ndarray):
102
+ assert img.ndim == 3 and 1 <= img.shape[2] <= 3
103
+ new, x, y = np.float32, 1, 0
104
+ flip = lambda i,d: np.flip(i,axis=d)
105
+ elif isinstance(img, torch.Tensor):
106
+ assert img.ndim == 3 and 1 <= img.shape[0] <= 3
107
+ new, x, y = trf.new, -1, -2
108
+ flip = lambda i,d: i.flip(dims=[d])
109
+ H, W = img.shape[y], img.shape[x]
110
+
111
+ if angle == 90:
112
+ # point 0,0 --> (0, H-1); W-1,0 --> 0,0
113
+ img = flip(img.swapaxes(x,y),y)
114
+ trf = trf @ new([[0,-1,W-1],[1,0,0],[0,0,1]]) # inverse transform: new --> current
115
+ if angle == 180:
116
+ # point 0,0 --> (W-1, H-1)
117
+ img = flip(flip(img,x),y)
118
+ trf = trf @ new([[-1,0,W-1],[0,-1,H-1],[0,0,1]]) # inverse transform: new --> current
119
+ if angle == 270:
120
+ # point 0,0 --> (H-1, 0); 0,H-1 --> 0,0
121
+ img = flip(img.swapaxes(x,y),x)
122
+ trf = trf @ new([[0,1,0],[-1,0,H-1],[0,0,1]]) # inverse transform: new --> current
123
+ return img, trf
124
+
125
+
126
+ def encode_scale_rot(scale, rot):
127
+ s = np.int32(np.rint(np.log(scale) / (0.5*np.log(2))))
128
+ r = np.int32(np.rint(((-rot) % 360) / 45)) % 8
129
+ return 8*s + (r%8)
130
+
131
+ def decode_scale_rot( code ):
132
+ s = code // 8
133
+ r = (code % 8)
134
+ return 2 ** (s/2), -((45 * r + 180) % 360 - 180)
135
+
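+ # Worked sketch: encode_scale_rot(2.0, 90) -> s=2 (half-octave scale steps), r=6 (45-degree steps),
+ # hence code 8*2 + 6 = 22; decode_scale_rot(22) recovers (2.0, 90.0).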
136
+
137
+ def normalized_corr(patches, img, padding='ncc', extra_patch=False, ret_norms=False):
138
+ assert patches.ndim == 4, 'patches shape must be (H*W, C, K, K)'
139
+ P, C, K, K = patches.shape
140
+ assert img.ndim == 3 and img.shape[0] == C, 'img shape must be (C, W, H)'
141
+ eps = torch.finfo(patches.dtype).tiny
142
+
143
+ # normalize on patches side
144
+ norms = patches.view(P,-1).norm(dim=-1)
145
+ patches = patches / norms[:,None,None,None].clamp(min=eps)
146
+
147
+ # convolve normalized patches on unnormalized image
148
+ ninth = 0
149
+ if padding == 'ninth':
150
+ ninth = img[:,-1].mean() # ninth dimension
151
+ img = F.pad(img[None], (K//2,K//2)*2, mode='constant', value=ninth)[0]
152
+
153
+ corr = F.conv2d(img[None], patches, padding=0, bias=None)[0]
154
+
155
+ # normalize on img's side
156
+ ones = patches.new_ones((1, C, K, K))
157
+ local_norm = torch.sqrt(F.conv2d(img[None]**2, ones))[0]
158
+ corr /= local_norm
159
+
160
+ # normalize on patches' side (image borders)
161
+ if padding == 'ncc':
162
+ local_norm = torch.sqrt(F.conv2d(ones, patches**2, padding=2))[0]
163
+ local_norm.clamp_(min=eps)
164
+ for j in range(-2, 3):
165
+ for i in range(-2,3):
166
+ if i == j == 2: continue # normal case is already normalized
167
+ if i == 2: i = slice(2,-2)
168
+ if j == 2: j = slice(2,-2)
169
+ corr[:,j,i] /= local_norm[:,j,i]
170
+
171
+ return (corr, norms) if ret_norms else corr
172
+
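+ # Shape sketch (illustrative values):
+ #   patches = torch.randn(16, 3, 5, 5)              # 16 patches of 5x5 over 3 channels
+ #   img     = torch.randn(3, 64, 64)
+ #   normalized_corr(patches, img, padding='ninth')  # -> (16, 64, 64), one map per patch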
173
+
174
+ def true_corr_shape( corr_shape, level ):
175
+ H1, W1, H2, W2 = corr_shape[-4:]
176
+ if level > 0: # recover true size
177
+ H1, W1 = H1-1, W1-1
178
+ return corr_shape[:-4] + (H1, W1, H2, W2)
179
+
180
+ def children(level, H1, W1, H2, W2):
181
+ """ level: parent level (> 1) """
182
+ gap = 2**(level-2)
183
+ # @ level 1: gap=0.5 (parent at x=1 has children at x=[0.5, 1.5])
184
+ # @ level 2: gap=1 (parent at x=1 has children at x=[0, 2])
185
+ # @ level 3: gap=2 (parent at x=2 has children at x=[0, 4])
186
+ # etc.
187
+
188
+ def ravel_child(x, y):
189
+ # x,y is the center of the child patch
190
+ inside = (0 <= x <= W1) and (0 <= y <= H1)
191
+ if gap < 1:
192
+ assert x % 1 == y % 1 == 0.5, bb()
193
+ return int((x-0.5) + (y-0.5) * W1) if inside else -1
194
+ else:
195
+ assert x % 1 == y % 1 == 0, bb()
196
+ return int(x + y * (W1+1)) if inside else -1
197
+
198
+ # 4 children for each parent patch (top-left, top-right, bot-left, bot-right, -1 = None)
199
+ parents = []
200
+ for h in range(H1+1):
201
+ for w in range(W1+1):
202
+ # enumerate the 4 children for this patch
203
+ children = [ravel_child(w + gap*tx, h + gap*ty) for ty in (-1,1) for tx in (-1,1)]
204
+ parents.append(children)
205
+
206
+ return torch.tensor(parents, dtype=torch.int64)
207
+
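+ # Worked sketch (level=1, H1=W1=2, children indexed on a 2x2 grid): the parent centered at
+ # (w=1, h=1) gets children [0, 1, 2, 3]; the corner parent (w=0, h=0) only keeps its
+ # bottom-right child and gets [-1, -1, -1, 0].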
208
+
209
+ def sparse_conv(level, corr, weights=None, reverse=False, norm=0.9):
210
+ H1, W1, H2, W2 = true_corr_shape(corr.shape, level-1 + reverse)
211
+ parents = children(level, H1, W1, H2, W2).to(corr.device)
212
+ n_parents = len(parents)
213
+
214
+ # perform the sparse convolution 'manually'
215
+ # since sparse convolutions are not implemented in pytorch currently
216
+ corr = corr.view(-1, *corr.shape[-2:])
217
+ if not reverse:
218
+ res = corr.new_zeros((n_parents+1,)+corr.shape[-2:]) # last one = garbage channel
219
+ nrm = corr.new_full((n_parents+1,3,3), 1e-8)
220
+ ones = nrm.new_ones((len(corr),1,1))
221
+ ex = 1
222
+ if weights is not None:
223
+ weights = weights.view(len(corr),1,1)
224
+ corr *= weights # apply weights to correlation maps without increasing memory footprint
225
+ ones *= weights
226
+ else:
227
+ assert corr._base is not None and corr._base.shape[0] == n_parents+1
228
+ corr._base[-1] = 0 # reset garbage layer
229
+ ex = 1 if level > 1 else 0
230
+ n_children = (H1+ex) * (W1+ex)
231
+ res = corr.new_zeros((n_children,)+corr.shape[-2:])
232
+
233
+ sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None)
234
+ c = 0
235
+ for y in (-1, 1):
236
+ for x in (-1, 1):
237
+ src_layers = parents[:,c]; c+= 1
238
+ # we want to do: res += corr[src_layers] (for all children != -1)
239
+ # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
240
+ tgt_layers = inverse_mapping(src_layers, max_elem=len(corr), default=n_parents)[:-1]
241
+
242
+ if not reverse:
243
+ # All of corr's channels MUST be utilized. for level>1, this doesn't hold,
244
+ # so we'll send them to a garbage channel ==> res[n_parents]
245
+ sel = good_slice( tgt_layers < n_parents )
246
+
247
+ res[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], corr[sel,sl(y),sl(x)])
248
+ nrm[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))
249
+ else:
250
+ ''' parent=199=11*17+12 @ (x=48, y=44) at level=1
251
+ |-- child=171 @ (x=46,y=42) at level0
252
+ |-- child=172 @ (x=50,y=42) at level0
253
+ |-- child=187 @ (x=46,y=46) at level0
254
+ |-- child=188 @ (x=50,y=46) at level0
255
+ '''
256
+ out = res[:,sl(y),sl(x)]
257
+ sel = tgt_layers[:n_children]
258
+ torch.maximum(out, corr._base[sel,sl(-y),sl(-x)], out=out)
259
+
260
+ if not reverse:
261
+ if weights is not None: corr /= weights.clamp(min=1e-12) # cancel weights
262
+ weights = norm_borders(res, nrm, norm=norm)[:-1]
263
+ res = res[:-1] # remove garbage channel
264
+ res = res.view(H1+ex, W1+ex, *res.shape[-2:])
265
+ return res if reverse else (res, weights)
266
+
267
+ def norm_borders( res, nrm, norm=0.9 ):
268
+ """ apply some border normalization, modulated by `norm`
269
+ - if norm=0: no normalization at all
270
+ - if norm=1: full normalization
271
+ Formula: nrm = k * (nrm/k)**p = k**(1-p) * nrm**p,
272
+ with k=nrm[:,1,1] and p=norm
273
+ """
274
+ new_weights = nrm[...,1,1].clone()
275
+ nrm = (nrm[...,1:2,1:2] ** (1-norm)) * (nrm ** norm)
276
+ # assert not torch.isnan(nrm).any()
277
+
278
+ # normalize results on the borders
279
+ res[...,0 ,0 ] /= nrm[...,0 ,0 ]
280
+ res[...,0 ,1:-1] /= nrm[...,0 ,1:2]
281
+ res[...,0 , -1] /= nrm[...,0 ,2 ]
282
+ res[...,1:-1,0 ] /= nrm[...,1:2,0 ]
283
+ res[...,1:-1,1:-1] /= nrm[...,1:2,1:2]
284
+ res[...,1:-1, -1] /= nrm[...,1:2,2 ]
285
+ res[..., -1,0 ] /= nrm[...,2 ,0 ]
286
+ res[..., -1,1:-1] /= nrm[...,2 ,1:2]
287
+ res[..., -1, -1] /= nrm[...,2 ,2 ]
288
+ return new_weights
289
+
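+ # Worked sketch: with norm=0.9 and a center count k=4, a corner cell reached by a single child
+ # is divided by 4**0.1 * 1**0.9 ~= 1.15 (norm=1 would divide it by 1, norm=0 by k=4).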
290
+
291
+ def inverse_mapping( map, max_elem=None, default=None):
292
+ """ given a mapping {i:j} we output {j:i}
293
+ (the mapping is a torch array)
294
+ """
295
+ assert isinstance(map, torch.Tensor) and map.ndim == 1
296
+ if max_elem is None: max_elem = map.max()
297
+ if default is None:
298
+ index = torch.empty(max_elem+1, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
299
+ else:
300
+ index = torch.full((max_elem+1,), default, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
301
+ index[map] = torch.arange(len(map), device=map.device)
302
+ return index
303
+
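+ # Example sketch: inverse_mapping(torch.tensor([3, 0, 2]), max_elem=4, default=-1)
+ # -> tensor([1, -1, 2, 0, -1]): position j holds the i such that map[i] == j.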
304
+
305
+ def good_slice( nonzero ):
306
+ good = nonzero.nonzero().ravel()
307
+ return slice(good.min().item(), good.max().item()+1)
308
+
309
+
310
+ def max_unpool(upper, lower, exclude_border=True):
311
+ # re-compute max-pool indices
312
+ if exclude_border:
313
+ # apparently, we cannot unpool on the bottom and right borders in legacy code (local_argmax with ex=1)
314
+ _, pos = F.max_pool2d(lower[:,:,:-1,:-1], 3, padding=1, stride=2, return_indices=True, ceil_mode=True)
315
+ W1 = lower.shape[-1]
316
+ pos = (pos//(W1-1))*W1 + (pos%(W1-1)) # fix the shortening
317
+ else:
318
+ _, pos = F.max_pool2d(lower, 3, padding=1, stride=2, return_indices=True)
319
+
320
+ # because there are potential collisions between overlapping 3x3 cells,
321
+ # that pytorch does not handle, we unpool in 4 successive non-overlapping steps.
322
+ for i in range(2):
323
+ for j in range(2):
324
+ # stride=0 instead of 1 because pytorch does some size checking, this is a hack
325
+ tmp = F.max_unpool2d(upper[:,:,i::2,j::2], pos[:,:,i::2,j::2], kernel_size=3, padding=0, stride=4, output_size=lower.shape[-2:])
326
+ if i == j == 0:
327
+ res = tmp
328
+ else:
329
+ torch.maximum(res, tmp, out=res)
330
+
331
+ # add scores to existing lower correlation map
332
+ lower += res
333
+ return lower
334
+
335
+
336
+ def mgrid( shape, **kw ):
337
+ """ Returns in (x, y) order (contrary to numpy which is (y,x) """
338
+ if isinstance(shape, torch.Tensor): shape = shape.shape
339
+ res = torch.meshgrid(*[torch.arange(n, dtype=torch.float32, **kw) for n in shape], indexing='ij')
340
+ return torch.stack(res[::-1], dim=-1).view(-1,2)
341
+
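+ # Example sketch: mgrid((2, 3)) -> [[0., 0.], [1., 0.], [2., 0.], [0., 1.], [1., 1.], [2., 1.]]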
342
+
343
+ def check_corres( corres, step, rot=None ):
344
+ H, W, two = corres.shape
345
+ assert two == 2
346
+ if isinstance(corres, np.ndarray):
347
+ corres = torch.from_numpy(corres)
348
+ if rot is not None:
349
+ corres = affmul(rot, corres)
350
+ gt = mgrid(corres.shape[:2]).view(H,W,2)
351
+ assert ((gt - corres // step).abs() <= 2).float().mean() > 0.99, bb()
352
+
353
+
354
+ def best_correspondences(corr):
355
+ """ All positions are returned as x1, y1, x2, y2
356
+ """
357
+ if isinstance(corr, tuple): return corr # for legacy
358
+ H1, W1, H2, W2 = corr.shape
359
+ fix1 = lambda arr: 4*arr+2 # center of cells in img1
360
+ div = lambda a,b: torch.div(a, b, rounding_mode='trunc') # because of warning in pytorch 1.9+
361
+
362
+ # best scores in img1
363
+ score1, pos1 = corr.view(H1, W1, H2*W2).max(dim=-1)
364
+ pos1 = torch.cat((fix1(mgrid(score1, device=pos1.device)), pos1.view(-1,1)%W2, div(pos1.view(-1,1),W2)), dim=-1)
365
+
366
+ # best scores in img2
367
+ score2, pos2 = max_pool3d( corr, kernel_size=4, stride=4 )
368
+ pos2, score2 = pos2.view(-1,1), score2.squeeze()
369
+ pos2 = torch.cat((fix1(div(pos2,W2*H2)%W1), fix1(div(pos2,(W1*H2*W2))), pos2%W2, div(pos2,W2)%H2), dim=-1).float()
370
+
371
+ return (pos1, score1), (pos2, score2)
372
+
373
+
374
+ def intersection( set1_, set2_ ):
375
+ """ Returns the indices of values in set1 that are duplicated in set2
376
+ """
377
+ set1, map1 = set1_.squeeze().unique(return_inverse=True) # map1: i1 -> j1
378
+ set2 = set2_.squeeze().unique()
379
+ combined = torch.cat((set1, set2))
380
+
381
+ uniques, inverse, counts = combined.unique(return_counts=True, return_inverse=True)
382
+ # j -> u, i -> j, j -> n
383
+ # we are interested only in (j -> i) for n > 1:
384
+ # assert counts.max() <= 2, 'there were non-unique values in either set1 or set2'+bb()
385
+ # intersected_values = uniques[counts > 1]
386
+ inverse1 = inverse_mapping(inverse[:len(set1)], max_elem=len(uniques)-1)
387
+ intersected_indices1 = inverse1[counts>1]
388
+ return inverse_mapping(map1, max_elem=len(set1)-1)[intersected_indices1]
389
+
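+ # Example sketch: intersection(torch.tensor([3, 5, 7]), torch.tensor([5, 9])) -> tensor([1]),
+ # since set1[1] == 5 is the only value also present in set2.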
390
+
391
+ def reciprocal(self, corres1, corres2 ):
392
+ pos1, score1 = corres1
393
+ pos2, score2 = corres2
394
+ (H1, W1), (H2, W2) = score1.shape, map(lambda i: 4*i+1, score2.shape)
395
+
396
+ to_int = pos1.new_tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
397
+ inter1 = intersection(pos1@to_int, pos2@to_int)
398
+ res = torch.cat((pos1[inter1], score1.view(-1,1)[inter1], 0*score1.view(-1,1)[inter1]), dim=-1)
399
+ return res
400
+
401
+
402
+ def max_pool3d( corr, kernel_size=4, stride=4 ):
403
+ H1, W1, H2, W2 = corr.shape
404
+ ks, st = kernel_size, stride
405
+ if corr.numel() >= 2**31 and corr.device != torch.device('cpu'):
406
+ # re-implementation due to a bug in pytorch
407
+ import core.cuda_deepm as kernels
408
+ return kernels.max_pool3d( corr.view(1, H1*W1, H2, W2), kernel_size, stride)
409
+ else:
410
+ return F.max_pool3d( corr.view(1, 1, H1*W1, H2, W2), kernel_size=(H1*W1,ks,ks), stride=(1,st,st), return_indices=True)
411
+
412
+
413
+ def forward_cuda(self, level, lower, weights=None, pooled=False):
414
+ import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
415
+ assert lower.numel() < 2**31, 'please use cuda-lowmem, pytorch cannot handle big tensors'
416
+ pooled = lower if pooled else F.max_pool2d(lower, 3, padding=1, stride=2)
417
+ return kernels.forward_agg(level, self.border_inv, pooled, weights)
418
+
419
+ def forward_cuda_lowmem(self, level, lower, weights=None):
420
+ import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
421
+ return kernels.forward_pool_agg(level, self.border_inv, lower, weights)
422
+
423
+ def backward_cuda(self, level, pyramid):
424
+ import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
425
+ kernels.backward_agg_unpool(level, pyramid[level], pyramid[level-1], True)
426
+ # assert not torch.isnan(pyramid[level-1]).any(), bb()
427
+ return pyramid[level-1]
428
+
429
+ def merge_corres(self, corres, rots, all_corres, code):
430
+ " rot : reference --> rotated "
431
+ all_step = self.matcher.pixel_desc.get_atomic_patch_size() // 2 # step size in all_corres
432
+ dev = all_corres[0][1].device
433
+
434
+ # stack correspondences
435
+ corres = [torch.cat((p.view(*s.shape,4),s[:,:,None],torch.full_like(s[:,:,None],code)),dim=2) for (p,s) in corres]
436
+
437
+ import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
438
+ kernels.merge_corres_one_side( corres[0].to(dev), 0, rots[0].to(dev), all_corres[0][1], all_step )
439
+ kernels.merge_corres_one_side( corres[1].to(dev), 2, rots[1].to(dev), all_corres[1][1], all_step )
440
+
core/losses/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from .multiloss import MultiLoss
6
+ from .pixel_ap_loss import PixelAPLoss
7
+ from .ap_loss_sampler import NghSampler
8
+ from .unsupervised_deepmatching_loss import DeepMatchingLoss
core/losses/ap_loss.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class APLoss (nn.Module):
11
+ """ differentiable AP loss, through quantization.
12
+
13
+ Input: (N, M) values in [min, max]
14
+ label: (N, M) values in {0, 1}
15
+
16
+ Returns: list of query AP (for each n in {1..N})
17
+ Note: typically, you want to minimize 1 - mean(AP)
18
+ """
19
+ def __init__(self, nq=25, min=0, max=1, euc=False):
20
+ nn.Module.__init__(self)
21
+ assert isinstance(nq, int) and 2 <= nq <= 100
22
+ self.nq = nq
23
+ self.min = min
24
+ self.max = max
25
+ self.euc = euc
26
+ gap = max - min
27
+ assert gap > 0
28
+
29
+ # init quantizer = non-learnable (fixed) convolution
30
+ self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True).requires_grad_(False)
31
+ a = (nq-1) / gap
32
+ #1st half = lines passing through (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
33
+ q.weight.data[:nq] = -a
34
+ q.bias.data[:nq] = a*min + torch.arange(nq, 0, -1) # b = 1 + a*(min+x)
35
+ #2nd half = lines passing through (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
36
+ q.weight.data[nq:] = a
37
+ q.bias.data[nq:] = torch.arange(2-nq, 2, 1) - a*min # b = 1 - a*(min+x)
38
+ # first and last one are special: just horizontal straight line
39
+ q.weight.data[0] = q.weight.data[-1] = 0
40
+ q.bias.data[0] = q.bias.data[-1] = 1
41
+
42
+ def compute_AP(self, x, label):
43
+ N, M = x.shape
44
+ if self.euc: # euclidean distance in the same range as similarities
45
+ x = 1 - torch.sqrt(2.001 - 2*x)
46
+
47
+ # quantize all predictions
48
+ q = self.quantizer(x.unsqueeze(1))
49
+ q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M
50
+
51
+ nbs = q.sum(dim=-1) # number of samples N x Q = c
52
+ rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q
53
+ prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision
54
+ rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1]
55
+
56
+ ap = (prec * rec).sum(dim=-1) # per-image AP
57
+ return ap
58
+
59
+ def forward(self, x, label):
60
+ assert x.shape == label.shape # N x M
61
+ return self.compute_AP(x, label)
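+ # Usage sketch (illustrative values):
+ #   aploss = APLoss(nq=25)
+ #   scores = torch.tensor([[0.9, 0.8, 0.3, 0.1]])   # N=1 query, M=4 items
+ #   labels = torch.tensor([[1., 0., 1., 0.]])
+ #   ap = aploss(scores, labels)                     # shape (1,), one AP per query
+ #   loss = 1 - ap.mean()                            # what is typically minimized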
core/losses/ap_loss_sampler.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class NghSampler (nn.Module):
14
+ """ Given dense feature maps and pixel-dense flow,
15
+ compute a subset of all correspondences and return their scores and labels.
16
+
17
+ Distance to GT => 0 ... pos_d ... neg_d ... ngh
18
+ Pixel label => + + + + + + 0 0 - - - - - - -
19
+
20
+ Subsample on query side: if > 0, regular grid
21
+ < 0, random points
22
+ In both cases, the number of query points is = W*H/subq**2
23
+ """
24
+ def __init__(self, ngh, subq=-8, subd=1, pos_d=2, neg_d=4, border=16, subd_neg=-8):
25
+ nn.Module.__init__(self)
26
+ assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
27
+ self.ngh = ngh
28
+ self.pos_d = pos_d
29
+ self.neg_d = neg_d
30
+ assert subd <= ngh or ngh == 0
31
+ assert subq != 0
32
+ self.sub_q = subq
33
+ self.sub_d = subd
34
+ self.sub_d_neg = subd_neg
35
+ if border is None: border = ngh
36
+ assert border >= ngh, 'border has to be larger than ngh'
37
+ self.border = border
38
+ self.precompute_offsets()
39
+
40
+ def precompute_offsets(self):
41
+ pos_d2 = self.pos_d**2
42
+ neg_d2 = self.neg_d**2
43
+ rad2 = self.ngh**2
44
+ rad = (self.ngh//self.sub_d) * self.sub_d # round ngh down to an integer multiple of sub_d
45
+ pos = []
46
+ neg = []
47
+ for j in range(-rad, rad+1, self.sub_d):
48
+ for i in range(-rad, rad+1, self.sub_d):
49
+ d2 = i*i + j*j
50
+ if d2 <= pos_d2:
51
+ pos.append( (i,j) )
52
+ elif neg_d2 <= d2 <= rad2:
53
+ neg.append( (i,j) )
54
+
55
+ self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
56
+ self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
57
+
58
+ def gen_grid(self, step, aflow):
59
+ B, two, H, W = aflow.shape
60
+ dev = aflow.device
61
+ b1 = torch.arange(B, device=dev)
62
+ if step > 0:
63
+ # regular grid
64
+ x1 = torch.arange(self.border, W-self.border, step, device=dev)
65
+ y1 = torch.arange(self.border, H-self.border, step, device=dev)
66
+ H1, W1 = len(y1), len(x1)
67
+ shape = (B, H1, W1)
68
+ x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1)
69
+ y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1)
70
+ b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1)
71
+ else:
72
+ # randomly spread
73
+ n = (H - 2*self.border) * (W - 2*self.border) // step**2
74
+ x1 = torch.randint(self.border, W-self.border, (n,), device=dev)
75
+ y1 = torch.randint(self.border, H-self.border, (n,), device=dev)
76
+ x1 = x1[None,:].expand(B,n).reshape(-1)
77
+ y1 = y1[None,:].expand(B,n).reshape(-1)
78
+ b1 = b1[:,None].expand(B,n).reshape(-1)
79
+ shape = (B, n)
80
+ return b1, y1, x1, shape
81
+
82
+ def forward(self, feats, confs, aflow, **kw):
83
+ B, two, H, W = aflow.shape
84
+ assert two == 2, bb()
85
+ feat1, conf1 = feats[0], (confs[0] if confs else None)
86
+ feat2, conf2 = feats[1], (confs[1] if confs else None)
87
+
88
+ # positions in the first image
89
+ b_, y1, x1, shape = self.gen_grid(self.sub_q, aflow)
90
+
91
+ # sample features from first image
92
+ feat1 = feat1[b_, :, y1, x1]
93
+ qconf = conf1[b_, :, y1, x1].view(shape) if confs else None
94
+
95
+ #sample GT from second image
96
+ xy2 = (aflow[b_, :, y1, x1] + 0.5).long().t()
97
+ mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H)
98
+ mask = mask.view(shape)
99
+
100
+ def clamp(xy):
101
+ torch.clamp(xy[0], 0, W-1, out=xy[0])
102
+ torch.clamp(xy[1], 0, H-1, out=xy[1])
103
+ return xy
104
+
105
+ # compute positive scores
106
+ xy2p = clamp(xy2[:,None,:] + self.pos_offsets[:,:,None])
107
+ pscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2p[1], xy2p[0]])
108
+
109
+ # compute negative scores
110
+ xy2n = clamp(xy2[:,None,:] + self.neg_offsets[:,:,None])
111
+ nscores = torch.einsum('nk,ink->ni', feat1, feat2[b_, :, xy2n[1], xy2n[0]])
112
+
113
+ if self.sub_d_neg:
114
+ # add distractors from a grid
115
+ b3, y3, x3 = self.gen_grid(self.sub_d_neg, aflow)[:3]
116
+ distractors = feat2[b3, :, y3, x3]
117
+ dscores = torch.einsum('nk,ik->ni', feat1, distractors)
118
+ del distractors
119
+
120
+ # remove scores that correspond to positives or nulls
121
+ x2, y2 = xy2 = xy2.float()
122
+ xy3 = torch.stack((x3,y3)).float()
123
+ dis2 = torch.cdist((xy2+b_*512).T, (xy3+b3*512).T, compute_mode='donot_use_mm_for_euclid_dist')
124
+ dscores[dis2 < self.neg_d] = 0
125
+
126
+ scores = torch.cat((pscores, nscores, dscores), dim=1)
127
+
128
+ gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
129
+ gt[:, :pscores.shape[1]] = 1
130
+
131
+ return scores, gt, mask, qconf
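+ # Usage sketch (illustrative shapes; descriptors are L2-normalized in practice):
+ #   sampler = NghSampler(ngh=16, subq=-8, subd=2, pos_d=2, neg_d=4, border=16)
+ #   feats = (torch.randn(2, 128, 64, 64), torch.randn(2, 128, 64, 64))
+ #   aflow = torch.rand(2, 2, 64, 64) * 63                 # dense GT flow from img1 to img2
+ #   scores, gt, mask, qconf = sampler(feats, None, aflow)  # qconf is None without confidences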
core/losses/multiloss.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from tools.trainer import backward
12
+
13
+
14
+ class MultiLoss (nn.Module):
15
+ """ This functions handles both supervised and unsupervised samples.
16
+ """
17
+ def __init__(self, loss_sup, loss_unsup, alpha=0.3, inner_bw=True):
18
+ super().__init__()
19
+ assert 0 <= alpha
20
+ self.alpha_sup = 1 # coef of self-supervised loss
21
+ self.loss_sup = loss_sup
22
+
23
+ self.alpha_unsup = alpha # coef of unsupervised loss
24
+ self.loss_unsup = loss_unsup
25
+
26
+ self.inner_bw = inner_bw
27
+
28
+ def forward(self, desc1, desc2, homography, **kw):
29
+ sl_sup, sl_unsup = split_batch_sup_unsup(homography, 512 if self.inner_bw else 8)
30
+
31
+ inner_bw = self.inner_bw and self.training and torch.is_grad_enabled()
32
+ if inner_bw: (desc1, desc1_), (desc2, desc2_) = pause_gradient((desc1,desc2))
33
+ kw['desc1'], kw['desc2'], kw['homography'] = desc1, desc2, homography
34
+
35
+ (sup_name, sup_loss) ,= self.loss_sup(backward_loss=inner_bw*self.alpha_sup, **{k:v[sl_sup] for k,v in kw.items()}).items()
36
+ if inner_bw and sup_loss: sup_loss = backward(sup_loss) # backward to desc1 and desc2
37
+
38
+ (uns_name, uns_loss) ,= self.loss_unsup(**{k:v[sl_unsup] for k,v in kw.items()}).items()
39
+ uns_loss = self.alpha_unsup * uns_loss
40
+ if inner_bw and uns_loss: uns_loss = backward(uns_loss) # backward to desc1 and desc2
41
+
42
+ loss = sup_loss + uns_loss
43
+ return {'loss':(loss, [(desc1_,desc1.grad),(desc2_,desc2.grad)]), sup_name:float(sup_loss), uns_name:float(uns_loss)}
44
+
45
+
46
+ def pause_gradient( objs ):
47
+ return [(obj.detach().requires_grad_(True), obj) for obj in objs]
48
+
49
+
50
+ def split_batch_sup_unsup(homography, max_sup=512):
51
+ # split batch in supervised / unsupervised
52
+ i = int(torch.isfinite(homography[:,0,0]).sum()) # first occurrence
53
+ sl_sup, sl_unsup = slice(0, min(i,max_sup)), slice(i, None)
54
+
55
+ assert torch.isfinite(homography[sl_sup]).all(), 'batch is not properly sorted!'
56
+ assert torch.isnan(homography[sl_unsup]).all(), 'batch is not properly sorted!'
57
+ return sl_sup, sl_unsup
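+ # Example sketch: a batch of 5 where only the first 3 homographies are defined
+ #   H = torch.full((5, 3, 3), float('nan')); H[:3] = torch.eye(3)
+ #   split_batch_sup_unsup(H)   # -> (slice(0, 3), slice(3, None))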
core/losses/pixel_ap_loss.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .ap_loss import APLoss
11
+ from datasets.utils import applyh
12
+
13
+
14
+ class PixelAPLoss (nn.Module):
15
+ """ Computes the pixel-wise AP loss:
16
+ Given two images and ground-truth optical flow, computes the AP per pixel.
17
+
18
+ feat1: (B, C, H, W) pixel-wise features extracted from img1
19
+ feat2: (B, C, H, W) pixel-wise features extracted from img2
20
+ aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2
21
+ """
22
+ def __init__(self, sampler, nq=20, inner_bw=False, bw_step=256):
23
+ nn.Module.__init__(self)
24
+ self.aploss = APLoss(nq, min=0, max=1, euc=False)
25
+ self.name = 'pixAP'
26
+ self.sampler = sampler
27
+ self.inner_bw = inner_bw
28
+ self.bw_step = bw_step
29
+
30
+ def loss_from_ap(self, ap, rel):
31
+ return 1 - ap
32
+
33
+ def forward(self, desc1, desc2, homography, backward_loss=None, **kw):
34
+ if len(desc1) == 0: return dict(ap_loss=0)
35
+ aflow = aflow_from_H(homography, desc1)
36
+ descriptors = (desc1, desc2)
37
+ scores, gt, msk, qconf = self.sampler(descriptors, kw.get('reliability'), aflow)
38
+
39
+ # compute pixel-wise AP
40
+ n = msk.numel()
41
+ if n == 0: return dict(ap_loss=0)
42
+ scores, gt = scores.view(n,-1), gt.view(n,-1)
43
+
44
+ backward_loss = backward_loss or self.inner_bw
45
+ if self.training and torch.is_grad_enabled() and backward_loss:
46
+ # progressive loss computation and backward, low memory but slow
47
+ scores_, qconf_ = scores, qconf if qconf is not None else scores.new_ones(msk.shape)
48
+ scores = scores.detach().requires_grad_(True)
49
+ qconf = qconf_.detach().requires_grad_(True)
50
+ msk = msk.ravel()
51
+
52
+ loss = 0
53
+ for i in range(0, n, self.bw_step):
54
+ sl = slice(i, i+self.bw_step)
55
+ ap = self.aploss(scores[sl], gt[sl])
56
+ pixel_loss = self.loss_from_ap(ap, qconf.ravel()[sl] if qconf is not None else None)
57
+ l = backward_loss / msk.sum() * pixel_loss[msk[sl]].sum()
58
+ loss += float(l)
59
+ l.backward() # cumulate gradient
60
+ loss = (loss, [(scores_,scores.grad)])
61
+ if qconf_.requires_grad: loss[1].append((qconf_,qconf.grad))
62
+
63
+ else:
64
+ ap = self.aploss(scores, gt).view(msk.shape)
65
+ pixel_loss = self.loss_from_ap(ap, qconf)
66
+ loss = pixel_loss[msk].mean()
67
+
68
+ return dict(ap_loss=loss)
69
+
70
+
71
+ def make_grid(B, H, W, device ):
72
+ b = torch.arange(B, device=device).view(B,1,1).expand(B,H,W)
73
+ y = torch.arange(H, device=device).view(1,H,1).expand(B,H,W)
74
+ x = torch.arange(W, device=device).view(1,1,W).expand(B,H,W)
75
+ return b.view(B,H*W), torch.stack((x,y),dim=-1).view(B,H*W,2)
76
+
77
+
78
+ def aflow_from_H( H_1to2, feat1 ):
79
+ B, _, H, W = feat1.shape
80
+ b, pos1 = make_grid(B,H,W, feat1.device)
81
+ pos2 = applyh(H_1to2, pos1.float())
82
+ return pos2.view(B,H,W,2).permute(0,3,1,2)
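+ # Sanity sketch (assuming applyh applies the homography to (x, y) points): with
+ # H_1to2 = torch.eye(3)[None] and feat1 of shape (1, C, H, W), the returned flow satisfies
+ # aflow[0, :, y, x] == (x, y), i.e. every pixel is mapped onto itself.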
core/losses/unsupervised_deepmatching_loss.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from core import functional as myF
12
+
13
+
14
+ class DeepMatchingLoss (nn.Module):
15
+ """ This loss is based on DeepMatching (IJCV'16).
16
+ atleast: (int) minimum image size at which the pyramid construction stops.
17
+ sub: (int) prior subsampling
18
+ way: (str) which way to compute the asymmetric matching ('1', '2' or '12')
19
+ border: (int) ignore pixels too close to the border
20
+ rectify_p: (float) non-linear power-rectification in DeepMatching
21
+ eps: (float) epsilon for the L1 normalization. Kinda handles unmatched pixels.
22
+ """
23
+ def __init__(self, eps=0.03, atleast=5, sub=2, way='12', border=16, rectify_p=1.5):
24
+ super().__init__()
25
+ assert way in ('1','2','12')
26
+ self.subsample = sub
27
+ self.border = border
28
+ self.way = way
29
+ self.atleast = atleast
30
+ self.rectify_p = rectify_p
31
+ self.eps = eps
32
+
33
+ self._cache = {}
34
+
35
+ def rectify(self, corr):
36
+ corr = corr.clip_(min=0)
37
+ corr = corr ** self.rectify_p
38
+ return corr
39
+
40
+ def forward(self, desc1, desc2, **kw):
41
+ # 1 --> 2
42
+ loss1 = self.forward_oneway(desc1, desc2, **kw) \
43
+ if '1' in self.way else 0
44
+
45
+ # 2 --> 1
46
+ loss2 = self.forward_oneway(desc2, desc1, **kw) \
47
+ if '2' in self.way else 0
48
+
49
+ return dict(deepm_loss=(loss1+loss2)/len(self.way))
50
+
51
+ def forward_oneway(self, desc1, desc2, dbg=(), **kw):
52
+ assert desc1.shape[:2] == desc2.shape[:2]
53
+
54
+ # prior subsampling
55
+ s = slice(self.border, -self.border or None, self.subsample)
56
+ desc1, desc2 = desc1[...,s,s], desc2[...,s,s]
57
+ desc1 = desc1[:,:,2::4,2::4] # subsample patches in 1st image
58
+ B, D, H1, W1, H2, W2 = desc1.shape + desc2.shape[-2:]
59
+ if B == 0: return 0 # empty batch
60
+
61
+ # initial 4D correlation volume
62
+ corr = torch.bmm(desc1.reshape(B,D,-1).transpose(1,2), desc2.reshape(B,D,-1)).view(B,H1,W1,H2,W2)
63
+
64
+ # build pyramid
65
+ pyramid = self.deep_matching(corr)
66
+ corr = pyramid[-1] # high-level correlation
67
+ corr = self.rectify(corr)
68
+
69
+ # L1 norm
70
+ B, H1, W1, H2, W2 = corr.shape
71
+ corr = corr / (corr.reshape(B,H1*W1,-1).sum(dim=-1).view(B,H1,W1,1,1) + self.eps)
72
+
73
+ # squared L2 norm
74
+ loss = - torch.square(corr).sum() / (B*H1*W1)
75
+ return loss
76
+
77
+ def deep_matching(self, corr):
78
+ # print(f'level=0 {corr.shape=}')
79
+ weights = None
80
+ pyramid = [corr]
81
+ for level in range(1,999):
82
+ corr, weights = self.forward_level(level, corr, weights)
83
+ pyramid.append(corr)
84
+ # print(f'{level=} {corr.shape=}')
85
+ if weights.sum() == 0: break # img1 has become too small
86
+ if min(corr.shape[-2:]) < 2*self.atleast: break # img2 has become too small
87
+ return pyramid
88
+
89
+ def forward_level(self, level, corr, weights):
90
+ B, H1, W1, H2, W2 = corr.shape
91
+
92
+ # max-pooling
93
+ pooled = F.max_pool2d(corr.view(B,H1*W1,H2,W2), 3, padding=1, stride=2)
94
+ pooled = pooled.view(B, H1, W1, *pooled.shape[-2:])
95
+
96
+ # print(f'rectifying corr at {level=}')
97
+ pooled = self.rectify(pooled)
98
+
99
+ # sparse conv
100
+ key = level, H1, W1, H2, W2
101
+ if key not in self._cache:
102
+ B, H1, W1, H2, W2 = myF.true_corr_shape(pooled.shape, level-1)
103
+ self._cache[key] = myF.children(level, H1, W1, H2, W2).to(corr.device)
104
+
105
+ return sparse_conv(level, pooled, self._cache[key], weights)
106
+
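+ # Usage sketch (illustrative shapes; descriptors are typically L2-normalized):
+ #   loss_fn = DeepMatchingLoss(sub=2, border=16, way='12')
+ #   desc1 = F.normalize(torch.randn(2, 64, 128, 128), dim=1)
+ #   desc2 = F.normalize(torch.randn(2, 64, 128, 128), dim=1)
+ #   loss_fn(desc1, desc2)    # -> {'deepm_loss': tensor(...)}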
107
+
108
+ def sparse_conv(level, corr, parents, weights=None, border_norm=0.9):
109
+ B, H1, W1, H2, W2 = myF.true_corr_shape(corr.shape, level-1)
110
+ n_cache = len(parents)
111
+
112
+ # perform the sparse convolution 'manually'
113
+ # since sparse convolutions are not implemented in pytorch currently
114
+ corr = corr.view(B, -1, H2, W2)
115
+
116
+ res = corr.new_zeros((B, n_cache+1, H2, W2)) # last one = garbage channel
117
+ nrm = corr.new_full((n_cache+1, 3, 3), torch.finfo(corr.dtype).eps)
118
+ ones = nrm.new_ones((corr.shape[1], 1, 1))
119
+ ex = 1
120
+ if weights is not None:
121
+ weights = weights.view(corr.shape[1],1,1)
122
+ corr = corr * weights[None] # apply weights to correlation maps beforehand
123
+ ones *= weights
124
+
125
+ sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None)
126
+ c = 0
127
+ for y in (-1, 1):
128
+ for x in (-1, 1):
129
+ src_layers = parents[:,c]; c+= 1
130
+ # we want to do: res += corr[src_layers] (for all children != -1)
131
+ # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
132
+ tgt_layers = myF.inverse_mapping(src_layers, max_elem=corr.shape[1], default=n_cache)[:-1]
133
+
134
+ # All of corr's channels MUST be utilized. for level>1, this doesn't hold,
135
+ # so we'll send them to a garbage channel ==> res[n_cache]
136
+ sel = myF.good_slice( tgt_layers < n_cache )
137
+
138
+ res[:,:,sl(-y),sl(-x)].index_add_(1, tgt_layers[sel], corr[:,sel,sl(y),sl(x)])
139
+ nrm[ :,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))
140
+
141
+ # normalize borders
142
+ weights = myF.norm_borders(res, nrm, norm=border_norm)[:-1]
143
+
144
+ res = res[:,:-1] # remove garbage channel
145
+ return res.view(B, H1+ex, W1+ex, *res.shape[-2:]), weights
146
+
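
A minimal, self-contained sketch (illustrative only, not part of the commit) of what the forward pass above computes: two L2-normalized descriptor maps are correlated into a dense 4D volume, each source location's correlation map is rectified and L1-normalized, and the loss rewards peaked, near one-hot maps through a negative squared-L2 term. The `clamp(min=0)` below is only a stand-in for the model's `rectify` step, and the shapes are arbitrary.

    import torch
    import torch.nn.functional as F

    B, D, H1, W1, H2, W2 = 2, 128, 8, 8, 16, 16
    desc1 = F.normalize(torch.randn(B, D, H1, W1), dim=1)
    desc2 = F.normalize(torch.randn(B, D, H2, W2), dim=1)

    # dense 4D correlation volume between every pixel of img1 and img2
    corr = torch.bmm(desc1.reshape(B, D, -1).transpose(1, 2),
                     desc2.reshape(B, D, -1)).view(B, H1, W1, H2, W2)
    corr = corr.clamp(min=0)  # stand-in for self.rectify(corr)

    # L1-normalize each (H2, W2) correlation map, then apply the negative squared-L2 loss:
    # it is minimized when every source pixel correlates with a single target pixel
    eps = 1e-8
    corr = corr / (corr.reshape(B, H1 * W1, -1).sum(dim=-1).view(B, H1, W1, 1, 1) + eps)
    loss = -torch.square(corr).sum() / (B * H1 * W1)
    print(float(loss))
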
core/pixel_desc.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as tvf
11
+
12
+ from core.conv_mixer import ConvMixer
13
+
14
+ norm_RGB = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15
+
16
+
17
+ class PixelDesc (nn.Module):
18
+ def __init__(self, path='models/PUMP_st.pt'):
19
+ super().__init__()
20
+ state_dict = torch.load( path, 'cpu' )
21
+ self.pixel_desc = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9).eval()
22
+ self.pixel_desc.load_state_dict(state_dict)
23
+
24
+ def configure(self, pipeline):
25
+ # hot-update of the default HOG-based pipeline
26
+ pipeline.__class__ = type(type(pipeline).__name__+'_Trained', (DescPipeline, type(pipeline)), {})
27
+ return self
28
+
29
+ def get_atomic_patch_size(self):
30
+ return 4
31
+
32
+ def forward(self, img, stride=1, offset=0):
33
+ if img.ndim == 3: img = img[None]
34
+ trf = torch.eye(3, device=img.device)
35
+
36
+ desc = self.pixel_desc( img )
37
+ desc = desc[..., offset::stride, offset::stride].contiguous() # free memory
38
+ return desc, trf
39
+
40
+
41
+ class DescPipeline:
42
+ def extract_descs(self, img1, img2, dtype=None):
43
+ # this will rotate the image if needed
44
+ img1, sca1 = self.demultiplex_img_trf(img1)
45
+ img2, sca2 = self.demultiplex_img_trf(img2)
46
+
47
+ # convert to float and normalize std
48
+ fimg1, fimg2 = [norm_RGB(img.type(dtype)/255) for img in (img1, img2)]
49
+
50
+ self.pixel_desc.type(fimg1.dtype)
51
+ desc1, trf1 = self.pixel_desc(fimg1, stride=4, offset=2)
52
+ desc2, trf2 = self.pixel_desc(fimg2)
53
+ return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1@trf1, sca2@trf2)
54
+
55
+ def first_level(self, desc1, desc2, **kw):
56
+ B, C, H, W = desc1.shape
57
+ weights = desc1.permute(0, 2, 3, 1).reshape(H*W, C, 1, 1) # rearrange(desc1, '1 C H W -> (H W) C 1 1')
58
+ corr = F.conv2d(desc2, weights, padding=0, bias=None)[0]
59
+ norms = torch.ones(desc1.shape[-2:], device=corr.device)
60
+ return corr.view(desc1.shape[-2:]+desc2.shape[-2:]), norms
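
The `first_level` trick above (treating every pixel of `desc1` as a 1x1 convolution kernel applied to `desc2`) can be reproduced in isolation. The sketch below is illustrative and uses random stand-in tensors instead of the pretrained ConvMixer descriptors.

    import torch
    import torch.nn.functional as F

    C, H, W = 128, 32, 48
    desc1 = F.normalize(torch.randn(1, C, H, W), dim=1)  # stand-ins for PixelDesc outputs
    desc2 = F.normalize(torch.randn(1, C, H, W), dim=1)

    # each pixel of desc1 becomes one 1x1 kernel, so a single conv2d call produces
    # the full correlation volume between desc1 and desc2
    weights = desc1.permute(0, 2, 3, 1).reshape(H * W, C, 1, 1)
    corr = F.conv2d(desc2, weights, bias=None, padding=0)[0].view(H, W, H, W)
    # corr[y1, x1, y2, x2] == <desc1[:, :, y1, x1], desc2[:, :, y2, x2]>
    print(corr.shape)
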
datasets/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from .image_set import *
6
+ from .web_images import RandomWebImages
7
+ from .pair_dataset import *
8
+ from .pair_loader import *
9
+ from .sfm120k import *
datasets/demo_warp/mountains_src.jpg ADDED
datasets/demo_warp/mountains_tgt.jpg ADDED
datasets/image_set.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import os
7
+ from os.path import *
8
+ from PIL import Image
9
+
10
+
11
+ class ImageSet(object):
12
+ """ Base class for an image dataset.
13
+ """
14
+ def __init__(self, root, imgs):
15
+ self.root = root
16
+ self.imgs = imgs
17
+ assert imgs, f'Empty image set in {root}'
18
+
19
+ def init_from_folder(self, *args, **kw):
20
+ imset = ImageSet.from_folder(*args, **kw)
21
+ ImageSet.__init__(self, imset.root, imset.imgs)
22
+
23
+ def __len__(self):
24
+ return len(self.imgs)
25
+
26
+ def get_image_path(self, idx):
27
+ return os.path.join(self.root, self.imgs[idx])
28
+
29
+ def get_image(self, idx):
30
+ fname = self.get_image_path(idx)
31
+ try:
32
+ return Image.open(fname).convert('RGB')
33
+ except Exception as e:
34
+ raise IOError("Could not load image %s (reason: %s)" % (fname, str(e)))
35
+
36
+ __getitem__ = get_image
37
+
38
+ @staticmethod
39
+ def from_folder(root, exts=('.jpg','.jpeg','.png','.ppm'), recursive=False, listing=False, check_imgs=False):
40
+ """
41
+ recursive: bool or func. If a function, it must return True for a directory name in order to recurse into it.
42
+ """
43
+ if listing:
44
+ if listing is True: listing = f"list_imgs{'_recursive' if recursive else ''}.txt"
45
+ flist = join(root, listing)
46
+ try: return ImageSet.from_listing(root,flist)
47
+ except IOError: print(f'>> ImageSet.from_folder(listing=True): entering {root}...')
48
+
49
+ if check_imgs is True: # default verif function
50
+ check_imgs = verify_img
51
+
52
+ for _, dirnames, dirfiles in os.walk(root):
53
+ imgs = sorted([f for f in dirfiles if f.lower().endswith(exts)])
54
+ if check_imgs: imgs = [img for img in imgs if check_imgs(join(root,img))]
55
+
56
+ if recursive:
57
+ for dirname in sorted(dirnames):
58
+ if callable(recursive) and not recursive(join(root,dirname)): continue
59
+ imset = ImageSet.from_folder(join(root,dirname), exts=exts, recursive=recursive, listing=listing, check_imgs=check_imgs)
60
+ imgs += [join(dirname,f) for f in imset.imgs]
61
+ break # recursion is handled internally
62
+
63
+ if listing:
64
+ try: open(flist,'w').write('\n'.join(imgs))
65
+ except IOError: pass # write permission denied
66
+ return ImageSet(root, imgs)
67
+
68
+ @staticmethod
69
+ def from_listing(root, list_path):
70
+ return ImageSet(root, open(list_path).read().splitlines())
71
+
72
+ def circular_pad(self, min_size):
73
+ assert self.imgs, 'cannot pad an empty image set'
74
+ while len(self.imgs) < min_size:
75
+ self.imgs += self.imgs # artificially augment size
76
+ self.imgs = self.imgs[:min_size or None]
77
+ return self
78
+
79
+ def __repr__(self):
80
+ prefix = os.path.commonprefix((self.get_image_path(0),self.get_image_path(len(self)-1)))
81
+ return f'{self.__class__.__name__}({len(self)} images from {prefix}...)'
82
+
83
+
84
+
85
+ def verify_img(path, exts=None):
86
+ if exts and not path.lower().endswith(exts): return False
87
+ try:
88
+ Image.open(path).convert('RGB') # try to open it
89
+ return True
90
+ except:
91
+ return False
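
A minimal usage sketch for `ImageSet` (illustrative; it only assumes the two demo images added under datasets/demo_warp in this commit):

    from datasets.image_set import ImageSet

    imset = ImageSet.from_folder('datasets/demo_warp')   # indexes *.jpg/*.png/... files
    print(imset)                                         # ImageSet(2 images from ...)
    img = imset.get_image(0)                             # PIL.Image, converted to RGB
    print(img.size)
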
datasets/pair_dataset.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import os, os.path as osp
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch
11
+
12
+ from .image_set import ImageSet
13
+ from .transforms import instanciate_transforms
14
+ from .utils import DatasetWithRng
15
+ invh = np.linalg.inv
16
+
17
+
18
+ class ImagePairs (DatasetWithRng):
19
+ """ Base class for a dataset that serves image pairs.
20
+ """
21
+ imgs = None # regular image dataset
22
+ pairs = [] # list of (idx1, idx2), ...
23
+
24
+ def __init__(self, image_set, pairs, trf=None, **rng):
25
+ assert image_set and pairs, 'empty images or pairs'
26
+ super().__init__(**rng)
27
+ self.imgs = image_set
28
+ self.pairs = pairs
29
+ self.trf = instanciate_transforms(trf, rng=self.rng)
30
+
31
+ def __len__(self):
32
+ return len(self.pairs)
33
+
34
+ def __getitem__(self, idx):
35
+ transform = self.trf or (lambda x:x)
36
+ pair = tuple(map(transform, self._load_pair(idx)))
37
+ return pair, {}
38
+
39
+ def _load_pair(self, idx):
40
+ i,j = self.pairs[idx]
41
+ img1 = self.imgs.get_image(i)
42
+ return (img1, img1) if i == j else (img1, self.imgs.get_image(j))
43
+
44
+ def __repr__(self):
45
+ return f'{self.__class__.__name__}({len(self)} pairs from {self.imgs})'
46
+
47
+
48
+ class StillImagePairs (ImagePairs):
49
+ """ A dataset of 'still' image pairs used for debugging purposes.
50
+ """
51
+ def __init__(self, image_set, pairs=None, **rng):
52
+ if isinstance(image_set, ImagePairs):
53
+ super().__init__(image_set.imgs, pairs or image_set.pairs, **rng)
54
+ else:
55
+ super().__init__(image_set, pairs or [(i,i) for i in range(len(image_set))], **rng)
56
+
57
+ def __getitem__(self, idx):
58
+ img1, img2 = self._load_pair(idx)
59
+ sx, sy = img2.size / np.float32(img1.size)
60
+ return (img1, img2), dict(homography=np.diag(np.float32([sx, sy, 1])))
61
+
62
+
63
+ class SyntheticImagePairs (StillImagePairs):
64
+ """ A synthetic generator of image pairs.
65
+ Given a normal image dataset, it constructs pairs using random homographies & noise.
66
+
67
+ scale: prior image scaling.
68
+ distort: distortion applied independently to (img1,img2) if sym=True else just img2
69
+ sym: (bool) see above.
70
+ """
71
+ def __init__(self, image_set, scale='', distort='', sym=False, **rng):
72
+ super().__init__(image_set, **rng)
73
+ self.symmetric = sym
74
+ self.scale = instanciate_transforms(scale, rng=self.rng)
75
+ self.distort = instanciate_transforms(distort, rng=self.rng)
76
+
77
+ def __getitem__(self, idx):
78
+ (img1, img2), gt = super().__getitem__(idx)
79
+
80
+ img1 = dict(img=img1, homography=np.eye(3,dtype=np.float32))
81
+ if img1['img'] is img2:
82
+ img1 = self.scale(img1)
83
+ img2 = self.distort(dict(img1))
84
+ if self.symmetric: img1 = self.distort(img1)
85
+ else:
86
+ if self.symmetric: img1 = self.distort(self.scale(img1))
87
+ img2 = self.distort(self.scale(dict(img=img2, **gt)))
88
+
89
+ return (img1['img'], img2['img']), dict(homography=img2['homography'] @ invh(img1['homography']))
90
+
91
+ def __repr__(self):
92
+ format = lambda s: ','.join(l.strip() for l in repr(s).splitlines() if l).replace(',','',1)
93
+ return f"{self.__class__.__name__}({len(self)} images, scale={format(self.scale)}, distort={format(self.distort)})"
94
+
95
+
96
+ class CatImagePairs (DatasetWithRng):
97
+ """ Concatenation of several ImagePairs datasets
98
+ """
99
+ def __init__(self, *pair_datasets, seed=torch.initial_seed()):
100
+ assert all(isinstance(db, ImagePairs) for db in pair_datasets)
101
+ self.pair_datasets = pair_datasets
102
+ DatasetWithRng.__init__(self, seed=seed) # init last
103
+ self._init()
104
+
105
+ def _init(self):
106
+ self._pair_offsets = np.cumsum([0] + [len(db) for db in self.pair_datasets])
107
+ self.npairs = self._pair_offsets[-1]
108
+
109
+ def __len__(self):
110
+ return self.npairs
111
+
112
+ def __repr__(self):
113
+ fmt_str = f"{type(self).__name__}({len(self)} pairs,"
114
+ for i,db in enumerate(self.pair_datasets):
115
+ npairs = self._pair_offsets[i+1] - self._pair_offsets[i]
116
+ fmt_str += f'\n\t{npairs} from '+str(db).replace("\n"," ") + ','
117
+ return fmt_str[:-1] + ')'
118
+
119
+ def __getitem__(self, idx):
120
+ b, i = self._which(idx)
121
+ return self.pair_datasets[b].__getitem__(i)
122
+
123
+ def _which(self, i):
124
+ pos = np.searchsorted(self._pair_offsets, i, side='right')-1
125
+ assert pos < self.npairs, 'Bad pair index %d >= %d' % (i, self.npairs)
126
+ return pos, i - self._pair_offsets[pos]
127
+
128
+ def _call(self, func, i, *args, **kwargs):
129
+ b, j = self._which(i)
130
+ return getattr(self.pair_datasets[b], func)(j, *args, **kwargs)
131
+
132
+ def init_worker(self, tid):
133
+ for db in self.pair_datasets:
134
+ db.init_worker(tid)
135
+
136
+
137
+ class BalancedCatImagePairs (CatImagePairs):
138
+ """ Balanced concatenation of several ImagePairs datasets
139
+ """
140
+ def __init__(self, npairs=0, *pair_datasets, **kw):
141
+ assert isinstance(npairs, int) and npairs >= 0, 'BalancedCatImagePairs(npairs != int)'
142
+ assert len(pair_datasets) > 0, 'no dataset provided'
143
+
144
+ if len(pair_datasets) >= 3 and isinstance(pair_datasets[1], int):
145
+ assert len(pair_datasets) % 2 == 1
146
+ pair_datasets = [npairs] + list(pair_datasets)
147
+ npairs, pair_datasets = pair_datasets[0::2], pair_datasets[1::2]
148
+ assert all(isinstance(n, int) for n in npairs)
149
+ self._pair_offsets = np.cumsum([0]+npairs)
150
+ self.npairs = self._pair_offsets[-1]
151
+ else:
152
+ self.npairs = npairs or max(len(db) for db in pair_datasets)
153
+ self._pair_offsets = np.linspace(0, self.npairs, len(pair_datasets)+1).astype(int)
154
+ CatImagePairs.__init__(self, *pair_datasets, **kw)
155
+
156
+ def set_epoch(self, epoch):
157
+ DatasetWithRng.init_worker(self, epoch) # random seed only depends on the epoch
158
+ self._init() # reset permutations for this epoch
159
+
160
+ def init_worker(self, tid):
161
+ CatImagePairs.init_worker(self, tid)
162
+
163
+ def _init(self):
164
+ self._perms = []
165
+ for i,db in enumerate(self.pair_datasets):
166
+ assert len(db), 'cannot balance if there is an empty dataset'
167
+ avail = self._pair_offsets[i+1] - self._pair_offsets[i]
168
+ idxs = np.arange(len(db))
169
+ while len(idxs) < avail:
170
+ idxs = np.r_[idxs,idxs]
171
+ if self.seed: # if not seed, then no shuffle
172
+ self.rng.shuffle(idxs[(avail//len(db))*len(db):])
173
+ self._perms.append( idxs[:avail] )
174
+ # print(self._perms)
175
+
176
+ def _which(self, i):
177
+ pos, idx = super()._which(i)
178
+ return pos, self._perms[pos][idx]
179
+
180
+
181
+ class UnsupervisedPairs (ImagePairs):
182
+ """ Unsupervised image pairs obtained from SfM
183
+ """
184
+ def __init__(self, img_set, pair_file_path):
185
+ assert isinstance(img_set, ImageSet), bb()
186
+ self.pair_list = self._parse_pair_list(pair_file_path)
187
+ self.corres_dir = osp.join(osp.split(pair_file_path)[0], 'corres')
188
+
189
+ tag_to_idx = {n:i for i,n in enumerate(img_set.imgs)}
190
+ img_indices = lambda pair: tuple([tag_to_idx[n] for n in pair])
191
+ super().__init__(img_set, [img_indices(pair) for pair in self.pair_list])
192
+
193
+ def __repr__(self):
194
+ return f"{type(self).__name__}({len(self)} pairs from {self.imgs})"
195
+
196
+ def _parse_pair_list(self, pair_file_path):
197
+ res = []
198
+ for row in open(pair_file_path).read().splitlines():
199
+ row = row.split()
200
+ if len(row) != 2: raise IOError()
201
+ res.append((row[0], row[1]))
202
+ return res
203
+
204
+ def get_corres_path(self, pair_idx):
205
+ img1, img2 = [osp.basename(self.imgs.imgs[i]) for i in self.pairs[pair_idx]]
206
+ return osp.join(self.corres_dir, f'{img1}_{img2}.npy')
207
+
208
+ def get_corres(self, pair_idx):
209
+ return np.load(self.get_corres_path(pair_idx))
210
+
211
+ def __getitem__(self, idx):
212
+ img1, img2 = self._load_pair(idx)
213
+ return (img1, img2), dict(corres=self.get_corres(idx))
214
+
215
+
216
+ if __name__ == '__main__':
217
+ from datasets import *
218
+ from tools.viz import show_random_pairs
219
+
220
+ db = BalancedCatImagePairs(
221
+ 3125, SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)'),
222
+ 4875, SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)'),
223
+ 8000, SfM120k_Pairs())
224
+
225
+ show_random_pairs(db)
226
+
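
As a usage sketch (illustrative only), `SyntheticImagePairs` can be pointed at the demo images shipped in this commit. Each item returns the two images plus a ground-truth homography mapping img1 pixel coordinates to img2, which can be checked with `applyh` from `datasets.utils`; the seed and distortion string below are arbitrary.

    import numpy as np
    from datasets.image_set import ImageSet
    from datasets.pair_dataset import SyntheticImagePairs
    from datasets.utils import applyh

    imset = ImageSet.from_folder('datasets/demo_warp')
    db = SyntheticImagePairs(imset, distort='RandomTilting(0.5)', seed=0)

    (img1, img2), gt = db[0]
    H = gt['homography']                               # 3x3, img1 coords -> img2 coords
    center = np.float32([[img1.size[0] / 2, img1.size[1] / 2]])
    print(applyh(H, center))                           # where img1's center lands in img2
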
datasets/pair_loader.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ from PIL import Image
7
+ import numpy as np
8
+
9
+ from core import functional as myF
10
+ from tools.common import todevice
11
+ from .transforms import instanciate_transforms
12
+ from .utils import *
13
+
14
+
15
+ class FastPairLoader (DatasetWithRng):
16
+ """ On-the-fly generation of related image pairs
17
+ crop: random crop applied to both images
18
+ scale: random scaling applied to img2
19
+ distort: random distortion applied to img2
20
+
21
+ self[idx] returns: (img1, img2), dict(homography=)
22
+ (homography: 3x3 array, can be nan)
23
+ """
24
+ def __init__(self, dataset, crop=256, transform='', p_flip=0, p_swap=0, scale_jitter=0, seed=None):
25
+ super().__init__(seed)
26
+ self.dataset = self.with_same_rng(dataset)
27
+ self.transform = instanciate_transforms( transform, rng=self.rng )
28
+ self.crop_size = crop
29
+ self.p_swap = p_swap
30
+ self.p_flip = p_flip
31
+ self.scale_jitter = abs(np.log1p(scale_jitter))
32
+
33
+ def __len__(self):
34
+ return len(self.dataset)
35
+
36
+ def __repr__(self):
37
+ fmt_str = f'FastPairLoader({self.dataset},\n'
38
+ short_repr = lambda s: repr(s).strip().replace('\n',', ')[14:-1].replace(' ',' ')
39
+ fmt_str += ' Transform:\t%s\n' % short_repr(self.transform)
40
+ fmt_str +=f' Crop={self.crop_size}, scale_jitter=x{np.exp(self.scale_jitter):g}, p_swap={self.p_swap:g}'
41
+ return fmt_str
42
+
43
+ def init_worker(self, tid):
44
+ super().init_worker(tid)
45
+ self.dataset.init_worker(tid)
46
+
47
+ def set_epoch(self, epoch):
48
+ self.dataset.set_epoch(epoch)
49
+
50
+ def __getitem__(self, idx):
51
+ self.init_worker(idx) # preserve RNG for this pair
52
+ (img1, img2), gt = self.dataset[idx]
53
+
54
+ if self.rng.random() < self.p_swap:
55
+ img1, img2 = img2, img1
56
+ if 'homography' in gt: gt['homography'] = invh(gt['homography'])
57
+ if 'corres' in gt: gt['corres'] = swap_corres(gt['corres'])
58
+
59
+ if self.rng.random() < self.p_flip:
60
+ img1, img2, gt = flip_image_pair(img1, img2, gt)
61
+
62
+ # apply transformations to the second image
63
+ img2 = self.transform(dict(img=img2))
64
+
65
+ homography, corres = spatial_relationship( img1, img2, gt )
66
+
67
+ # find a good window
68
+ img1, img2 = map(self._pad_rgb_numpy, (img1, img2['img']))
69
+
70
+ if not 'debug':
71
+ from tools.viz import show_correspondences
72
+ print(np.median(corres[:,5]))
73
+ show_correspondences(img1, img2, corres, bb=bb)
74
+
75
+ def windows_from_corres( idx, scale_jitter=1 ):
76
+ c = corres[idx]
77
+ p1, p2, scale = c[0:2], c[2:4], c[6]
78
+ scale *= scale_jitter
79
+
80
+ # make windows based on scaling
81
+ win1 = window(*p1, self.crop_size, max(1, 1/scale), img1.shape)
82
+ win2 = window(*p2, self.crop_size, max(1, scale/1), img2.shape)
83
+ return win1, win2
84
+
85
+ best = 0, None
86
+ for idx in self.rng.choice(len(corres), size=min(len(corres),5), replace=False):
87
+ # pick a correspondence at random
88
+ win1, win2 = windows_from_corres( idx )
89
+
90
+ # check how many matches are in the 2 windows
91
+ score = score_windows(is_in(corres[:,0:2],win1), is_in(corres[:,2:4],win2))
92
+ if score > best[0]: best = score, idx
93
+
94
+ others = {}
95
+ if None in best: # couldn't find a good window
96
+ img1 = img2 = np.zeros((self.crop_size,self.crop_size,3), dtype=np.uint8)
97
+ corres = np.empty((0, 6), dtype=np.float32)
98
+ else:
99
+ # jitter scales
100
+ scale_jitter = np.exp(self.rng.uniform(-self.scale_jitter, self.scale_jitter))
101
+ win1, win2 = windows_from_corres( best[1], scale_jitter )
102
+ # print(win1, win2, img1.shape, img2.shape)
103
+ img1, img2 = imresize(img1[win1], self.crop_size), imresize(img2[win2], self.crop_size)
104
+ trf1, trf2 = wintrf(win1, img1), wintrf(win2, img2)
105
+
106
+ # fix rotation if necessary
107
+ angle_scores = np.bincount(corres[:,5].astype(int) % 8)
108
+ rot90 = int((((angle_scores.argmax() + 4) % 8) - 4) / 2)
109
+ if rot90: # rectify rotation
110
+ img2, trf = myF.rotate_img_90((img2, np.eye(3)), 90*rot90)
111
+ trf2 = invh(trf) @ trf2
112
+
113
+ homography = trf2 @ homography @ invh(trf1)
114
+ corres = myF.affmul((trf1,trf2), corres)
115
+
116
+ f32c = lambda i,**kw: np.require(i, requirements='CWAE', **kw)
117
+ return (f32c(img1), f32c(img2)), dict(homography = f32c(homography, dtype=np.float32), corres=corres, **others)
118
+
119
+ def _pad_rgb_numpy(self, img):
120
+ if img.mode != 'RGB':
121
+ img = img.convert('RGB')
122
+ if min(img.size) < self.crop_size:
123
+ w, h = img.size
124
+ result = Image.new('RGB', (max(w,self.crop_size), max(h,self.crop_size)), 0)
125
+ result.paste(img, (0, 0))
126
+ img = result
127
+ return np.asarray(img)
128
+
129
+
130
+
131
+ def swap_corres( corres ): # swap img1 and img2
132
+ res = corres.copy()
133
+ res[:,[0,1,2,3]] = corres[:,[2,3,0,1]]
134
+ if corres.shape[1] > 4: # invert rotation and scale
135
+ scale, rot = myF.decode_scale_rot(corres[:,5])
136
+ res[:,5] = myF.encode_scale_rot(1/scale, -rot)
137
+ return res
138
+
139
+ def flip(img):
140
+ w, h = img.size
141
+ return img.transpose(Image.FLIP_LEFT_RIGHT), np.float32( [[-1,0,w-1],[0,1,0],[0,0,1]] )
142
+
143
+ def flip_image_pair(img1, img2, gt):
144
+ img1, F1 = flip(img1)
145
+ img2, F2 = flip(img2)
146
+ res = {}
147
+ for key, value in gt.items():
148
+ if key == 'homography':
149
+ res['homography'] = F2 @ value @ F1
150
+ elif key == 'aflow':
151
+ assert False, 'flip for aflow: todo'
152
+ elif key == 'corres':
153
+ new_corres = np.c_[applyh(F1,value[:,0:2]), applyh(F2,value[:,2:4])]
154
+ if value.shape[1] == 4: pass
155
+ elif value.shape[1] == 6:
156
+ scale, rot = myF.decode_scale_rot(value[:,5])
157
+ new_code = myF.encode_scale_rot(scale, -rot)
158
+ new_corres = np.c_[new_corres,value[:,4],new_code]
159
+ res['corres'] = new_corres
160
+ else:
161
+ raise ValueError(f"flip_image_pair: bad gt field '{key}'")
162
+ return img1, img2, res
163
+
164
+
165
+ def spatial_relationship( img1, img2, gt ):
166
+ if 'homography' in gt:
167
+ homography = gt['homography']
168
+ if 'homography' in img2:
169
+ homography = np.float32(img2['homography']) @ homography
170
+ corres = corres_from_homography(homography, *img1.size)
171
+
172
+ elif 'corres' in gt:
173
+ homography = np.full((3,3), np.nan, dtype=np.float32)
174
+ corres = gt['corres']
175
+ if 'homography' in img2:
176
+ corres[:,2:4] = applyh(img2['homography'], corres[:,2:4])
177
+ else:
178
+ img2['homography'] = np.eye(3)
179
+ scales = np.sqrt(np.abs(np.linalg.det(jacobianh(img2['homography'], corres[:,0:2]).T)))
180
+
181
+ if corres.shape[1] == 4:
182
+ scales, rots = scale_rot_from_corres(corres)
183
+ corres = np.c_[corres, np.ones_like(scales), myF.encode_scale_rot(scales,rots*180/np.pi), scales]
184
+ elif corres.shape[1] == 6:
185
+ corres = np.c_[corres, scales * myF.decode_scale_rot(corres[:,5])[0]]
186
+ else:
187
+ assert ValueError(f'bad shape for corres: {corres.shape}')
188
+
189
+ return homography, corres
190
+
191
+
192
+ def scale_rot_from_corres( corres, sub=256, nn=16 ):
193
+ # select a subset of relevant correspondences
194
+ sub = np.random.choice(len(corres), size=min(len(corres),sub), replace=False)
195
+ sub = corres[sub]
196
+
197
+ # for each corres, find the scale change w.r.t. its NNs
198
+ from scipy.spatial.distance import cdist
199
+ nns = cdist(corres, sub, metric='sqeuclidean').argsort(axis=1)[:,:nn]
200
+
201
+ # affine transform for this set of neighboring correspondences
202
+ pts = sub[nns] # shape = npts x sub x 4
203
+ # [P1,1] @ A = P2 with A = 3x2 matrix
204
+ # A = [P1,1]^-1 @ P2
205
+ P1, P2 = pts[:,:,0:2], pts[:,:,2:4] # each row = list of correspondences
206
+ P1 = np.concatenate((P1,np.ones_like(P1[:,:,:1])),axis=-1)
207
+ A = (np.linalg.pinv(P1) @ P2).transpose(0,2,1)
208
+
209
+ scale, (angy,angx) = detect_scale_rotation(A.transpose(1,2,0)[:,1::-1])
210
+ rot = np.arctan2(angy, angx)
211
+ return scale.clip(min=0.2, max=5), rot
212
+
213
+
214
+ def window1(x, size, w):
215
+ l = x - int(0.5 + size / 2)
216
+ r = l + int(0.5 + size)
217
+ if l < 0: l,r = (0, r - l)
218
+ if r > w: l,r = (l + w - r, w)
219
+ if l < 0: l,r = 0,w # larger than width
220
+ return slice(l,r)
221
+
222
+ def window(cx, cy, win_size, scale, img_shape):
223
+ return (window1(int(cy), win_size*scale, img_shape[0]),
224
+ window1(int(cx), win_size*scale, img_shape[1]))
225
+
226
+ def is_in( pts, window ):
227
+ x, y = pts.T
228
+ sly, slx = window
229
+ return (slx.start <= x) & (x < slx.stop) & (sly.start <= y) & (y < sly.stop)
230
+
231
+ def score_windows( valid1, valid2 ):
232
+ inter = (valid1 & valid2).sum()
233
+ iou1 = inter / (valid1.sum() + 1e-8)
234
+ iou2 = inter / (valid2.sum() + 1e-8)
235
+ return inter * min(iou1, iou2)
236
+
237
+ def imresize( img, max_size, resample=Image.ANTIALIAS):
238
+ if max(img.shape[:2]) > max_size:
239
+ if img.shape[-1] == 2:
240
+ img = np.stack([np.float32(Image.fromarray(img[...,i]).resize((max_size,max_size), resample=resample)) for i in range(2)], axis=-1)
241
+ else:
242
+ img = np.asarray(Image.fromarray(img).resize((max_size,max_size), resample=resample))
243
+ assert img.shape[0] == img.shape[1] == max_size, bb()
244
+ return img
245
+
246
+ def wintrf( window, final_img ):
247
+ wy, wx = window
248
+ H, W = final_img.shape[:2]
249
+ T = np.float32((((wx.stop-wx.start)/W, 0, wx.start),
250
+ (0, (wy.stop-wy.start)/H, wy.start),
251
+ (0, 0, 1)) )
252
+ return invh(T)
253
+
254
+
255
+ def collate_ordered(batch, _use_shared_memory=True):
256
+ pairs, gt = zip(*batch)
257
+ imgs1, imgs2 = zip(*pairs)
258
+ assert len(imgs1) == len(imgs2) == len(gt) and isinstance(gt[0], dict)
259
+
260
+ # reorder samples (supervised ones first, unsupervised ones last)
261
+ supervised = [i for i,b in enumerate(gt) if np.isfinite(b['homography']).all()]
262
+ unsupervsd = [i for i,b in enumerate(gt) if np.isnan(b['homography']).any()]
263
+ order = supervised + unsupervsd
264
+
265
+ def collate( tensors, key=None ):
266
+ import torch
267
+ batch = todevice([tensors[i] for i in order], 'cpu')
268
+ if key == 'corres': return batch # cannot concat
269
+ if _use_shared_memory: # shared memory tensor to avoid an extra copy
270
+ numel = sum([x.numel() for x in batch])
271
+ storage = batch[0].storage()._new_shared(numel)
272
+ out = batch[0].new(storage)
273
+ return torch.stack(batch, dim=0, out=out)
274
+
275
+ return (collate(imgs1), collate(imgs2)), {k:collate([b[k] for b in gt],k) for k in gt[0]}
276
+
277
+
278
+ if __name__ == '__main__':
279
+ from datasets import *
280
+ from tools.viz import show_random_pairs
281
+
282
+ db = BalancedCatImagePairs(
283
+ 3125, SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)'),
284
+ 4875, SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)'),
285
+ 8000, SfM120k_Pairs())
286
+
287
+ db = FastPairLoader(db,
288
+ crop=256, transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise()',
289
+ p_swap=0.5, p_flip=0.5, scale_jitter=0, seed=777)
290
+
291
+ show_random_pairs(db)
datasets/sfm120k.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ from os.path import *
7
+
8
+ from .image_set import ImageSet
9
+ from .pair_dataset import UnsupervisedPairs
10
+
11
+
12
+ class SfM120k_Images (ImageSet):
13
+ def __init__(self, root='datasets/sfm120k'):
14
+ self.init_from_folder(join(root,'ims'), recursive=True, listing=True, exts='')
15
+
16
+
17
+ class SfM120k_Pairs (UnsupervisedPairs):
18
+ def __init__(self, root='datasets/sfm120k'):
19
+ super().__init__(SfM120k_Images(root=root), join(root,'list_pairs.txt'))
20
+
21
+
22
+ if __name__ == '__main__':
23
+ from tools.viz import show_random_pairs
24
+
25
+ db = SfM120k_Pairs()
26
+
27
+ show_random_pairs(db)
datasets/transforms.py ADDED
@@ -0,0 +1,540 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import warnings
7
+
8
+ import numpy as np
9
+ from PIL import Image, ImageOps
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torchvision import transforms as tvf
14
+
15
+ from . import transforms_tools as F
16
+ from .utils import DatasetWithRng
17
+
18
+ '''
19
+ Example command to try out some transformation chain:
20
+
21
+ python -m datasets.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(224)"
22
+ '''
23
+
24
+ def instanciate_transforms(transforms, use_gpu=False, rng=None, compose=True):
25
+ ''' Instantiate a sequence of transformations.
26
+
27
+ transforms: (str, list)
28
+ Comma-separated list of transformations.
29
+ Ex: "Rotate(10), Scale(256)"
30
+ '''
31
+ try:
32
+ transforms = transforms or '[]'
33
+
34
+ if isinstance(transforms, str):
35
+ if transforms.lstrip()[0] not in '[(': transforms = f'[{transforms}]'
36
+ if compose: transforms = f'Compose({transforms})'
37
+ transforms = eval(transforms)
38
+
39
+ if isinstance(transforms, list) and transforms and isinstance(transforms[0], str):
40
+ transforms = [eval(trf) for trf in transforms]
41
+ if compose: transforms = Compose(transforms)
42
+
43
+ if use_gpu and not isinstance(transforms, nn.Module):
44
+ while hasattr(transforms,'transforms') or hasattr(transforms,'transform'):
45
+ transforms = getattr(transforms,'transforms',getattr(transforms,'transform',None))
46
+ transforms = [trf for trf in transforms if isinstance(trf, nn.Module)]
47
+ transforms = nn.Sequential(*transforms) if compose else nn.ModuleList(transforms)
48
+
49
+ if transforms and rng:
50
+ for trf in transforms.transforms:
51
+ assert hasattr(trf, 'rng'), f"Transformation {trf} has no self.rng"
52
+ trf.rng = rng
53
+
54
+ if isinstance(transforms, Compose) and len(transforms.transforms) == 1:
55
+ transforms = transforms.transforms[0]
56
+ return transforms
57
+
58
+ except Exception as e:
59
+ print("\nError: Cannot interpret this transform list: %s\n" % transforms)
60
+ raise e
61
+
62
+
63
+
64
+ class Compose (DatasetWithRng):
65
+ def __init__(self, transforms, **rng_seed):
66
+ super().__init__(**rng_seed)
67
+ self.transforms = [self.with_same_rng(trf) for trf in transforms]
68
+
69
+ def __call__(self, data):
70
+ for trf in self.transforms:
71
+ data = trf(data)
72
+ return data
73
+
74
+
75
+ class Scale (DatasetWithRng):
76
+ """ Rescale the input PIL.Image to a given size.
77
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
78
+
79
+ The smallest dimension of the resulting image will be = size.
80
+
81
+ if largest == True: same behaviour for the largest dimension.
82
+
83
+ if not can_upscale: don't upscale
84
+ if not can_downscale: don't downscale
85
+ """
86
+ def __init__(self, size, interpolation=Image.BILINEAR, largest=False,
87
+ can_upscale=True, can_downscale=True, **rng_seed):
88
+ super().__init__(**rng_seed)
89
+ assert isinstance(size, int) or (len(size) == 2)
90
+ self.size = size
91
+ self.interpolation = interpolation
92
+ self.largest = largest
93
+ self.can_upscale = can_upscale
94
+ self.can_downscale = can_downscale
95
+
96
+ def __repr__(self):
97
+ fmt_str = "RandomScale(%s" % str(self.size)
98
+ if self.largest: fmt_str += ', largest=True'
99
+ if not self.can_upscale: fmt_str += ', can_upscale=False'
100
+ if not self.can_downscale: fmt_str += ', can_downscale=False'
101
+ return fmt_str+')'
102
+
103
+ def get_params(self, imsize):
104
+ w,h = imsize
105
+ if isinstance(self.size, int):
106
+ cmp = lambda a,b: (a>=b) if self.largest else (a<=b)
107
+ if (cmp(w, h) and w == self.size) or (cmp(h, w) and h == self.size):
108
+ ow, oh = w, h
109
+ elif cmp(w, h):
110
+ ow = self.size
111
+ oh = int(self.size * h / w)
112
+ else:
113
+ oh = self.size
114
+ ow = int(self.size * w / h)
115
+ else:
116
+ ow, oh = self.size
117
+ return ow, oh
118
+
119
+ def __call__(self, inp):
120
+ img = F.grab(inp,'img')
121
+ w, h = img.size
122
+
123
+ size2 = ow, oh = self.get_params(img.size)
124
+
125
+ if size2 != img.size:
126
+ a1, a2 = img.size, size2
127
+ if (self.can_upscale and min(a1) < min(a2)) or (self.can_downscale and min(a1) > min(a2)):
128
+ img = img.resize(size2, self.interpolation)
129
+
130
+ return F.update(inp, img=img, homography=np.diag((ow/w,oh/h,1)))
131
+
132
+
133
+
134
+ class RandomScale (Scale):
135
+ """Rescale the input PIL.Image to a random size.
136
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
137
+
138
+ Args:
139
+ min_size (int): min size of the smaller edge of the picture.
140
+ max_size (int): max size of the smaller edge of the picture.
141
+
142
+ ar (float or tuple):
143
+ max change of aspect ratio (width/height).
144
+
145
+ interpolation (int, optional): Desired interpolation. Default is
146
+ ``PIL.Image.BILINEAR``
147
+ """
148
+
149
+ def __init__(self, min_size, max_size, ar=1, larger=False,
150
+ can_upscale=False, can_downscale=True, interpolation=Image.BILINEAR):
151
+ Scale.__init__(self, (min_size,max_size), can_upscale=can_upscale, can_downscale=can_downscale, interpolation=interpolation)
152
+ assert type(min_size) == type(max_size), 'min_size and max_size can only be 2 ints or 2 floats'
153
+ assert isinstance(min_size, int) and min_size >= 1 or isinstance(min_size, float) and min_size>0
154
+ assert isinstance(max_size, (int,float)) and min_size <= max_size
155
+ self.min_size = min_size
156
+ self.max_size = max_size
157
+ if type(ar) in (float,int): ar = (min(1/ar,ar),max(1/ar,ar))
158
+ assert 0.2 < ar[0] <= ar[1] < 5
159
+ self.ar = ar
160
+ self.larger = larger
161
+
162
+ def get_params(self, imsize):
163
+ w,h = imsize
164
+ if isinstance(self.min_size, float): min_size = int(self.min_size*min(w,h) + 0.5)
165
+ if isinstance(self.max_size, float): max_size = int(self.max_size*min(w,h) + 0.5)
166
+ if isinstance(self.min_size, int): min_size = self.min_size
167
+ if isinstance(self.max_size, int): max_size = self.max_size
168
+
169
+ if not(self.can_upscale) and not(self.larger):
170
+ max_size = min(max_size,min(w,h))
171
+
172
+ size = int(0.5 + F.rand_log_uniform(self.rng, min_size, max_size))
173
+ if not(self.can_upscale) and self.larger:
174
+ size = min(size, min(w,h))
175
+
176
+ ar = F.rand_log_uniform(self.rng, *self.ar) # change of aspect ratio
177
+
178
+ if w < h: # image is taller
179
+ ow = size
180
+ oh = int(0.5 + size * h / w / ar)
181
+ if oh < min_size:
182
+ ow,oh = int(0.5 + ow*float(min_size)/oh),min_size
183
+ else: # image is wider
184
+ oh = size
185
+ ow = int(0.5 + size * w / h * ar)
186
+ if ow < min_size:
187
+ ow,oh = min_size,int(0.5 + oh*float(min_size)/ow)
188
+
189
+ assert ow >= min_size, 'image too small (width=%d < min_size=%d)' % (ow, min_size)
190
+ assert oh >= min_size, 'image too small (height=%d < min_size=%d)' % (oh, min_size)
191
+ return ow, oh
192
+
193
+
194
+
195
+ class RandomCrop (DatasetWithRng):
196
+ """Crop the given PIL Image at a random location.
197
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
198
+
199
+ Args:
200
+ size (sequence or int): Desired output size of the crop. If size is an
201
+ int instead of sequence like (h, w), a square crop (size, size) is
202
+ made.
203
+ padding (int or sequence, optional): Optional padding on each border
204
+ of the image. Default is 0, i.e no padding. If a sequence of length
205
+ 4 is provided, it is used to pad left, top, right, bottom borders
206
+ respectively.
207
+ """
208
+
209
+ def __init__(self, size, padding=0, **rng_seed):
210
+ super().__init__(**rng_seed)
211
+ if isinstance(size, int):
212
+ self.size = (int(size), int(size))
213
+ else:
214
+ self.size = size
215
+ self.padding = padding
216
+
217
+ def __repr__(self):
218
+ return "RandomCrop(%s)" % str(self.size)
219
+
220
+ def get_params(self, img, output_size):
221
+ w, h = img.size
222
+ th, tw = output_size
223
+ assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (w,h,tw,th)
224
+
225
+ y = self.rng.integers(0, h - th) if h > th else 0
226
+ x = self.rng.integers(0, w - tw) if w > tw else 0
227
+ return x, y, tw, th
228
+
229
+ def __call__(self, inp):
230
+ img = F.grab(inp,'img')
231
+
232
+ padl = padt = 0
233
+ if self.padding:
234
+ if F.is_pil_image(img):
235
+ img = ImageOps.expand(img, border=self.padding, fill=0)
236
+ else:
237
+ assert isinstance(img, F.DummyImg)
238
+ img = img.expand(border=self.padding)
239
+ if isinstance(self.padding, int):
240
+ padl = padt = self.padding
241
+ else:
242
+ padl, padt = self.padding[0:2]
243
+
244
+ i, j, tw, th = self.get_params(img, self.size)
245
+ img = img.crop((i, j, i+tw, j+th))
246
+
247
+ return F.update(inp, img=img, homography=np.float32(((1,0,padl-i),(0,1,padt-j),(0,0,1))))
248
+
249
+
250
+ class CenterCrop (RandomCrop):
251
+ """Crops the given PIL Image at the center.
252
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
253
+
254
+ Args:
255
+ size (sequence or int): Desired output size of the crop. If size is an
256
+ int instead of sequence like (h, w), a square crop (size, size) is
257
+ made.
258
+ """
259
+ @staticmethod
260
+ def get_params(img, output_size):
261
+ w, h = img.size
262
+ th, tw = output_size
263
+ y = int(0.5 +((h - th) / 2.))
264
+ x = int(0.5 +((w - tw) / 2.))
265
+ return x, y, tw, th
266
+
267
+
268
+ class RandomRotation (DatasetWithRng):
269
+ """Rescale the input PIL.Image to a random size.
270
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
271
+
272
+ Args:
273
+ degrees (float):
274
+ maximum rotation angle; the actual angle is drawn uniformly in [-degrees, degrees].
275
+
276
+ interpolation (int, optional): Desired interpolation. Default is
277
+ ``PIL.Image.BILINEAR``
278
+ """
279
+
280
+ def __init__(self, degrees, interpolation=Image.BILINEAR, **rng_seed):
281
+ super().__init__(**rng_seed)
282
+ self.degrees = degrees
283
+ self.interpolation = interpolation
284
+
285
+ def __repr__(self):
286
+ return f"RandomRotation({self.degrees})"
287
+
288
+ def __call__(self, inp):
289
+ img = F.grab(inp,'img')
290
+ w, h = img.size
291
+
292
+ angle = self.rng.uniform(-self.degrees, self.degrees)
293
+
294
+ img = img.rotate(angle, resample=self.interpolation)
295
+ w2, h2 = img.size
296
+
297
+ trf = F.translate(w2/2,h2/2) @ F.rotate(-angle * np.pi/180) @ F.translate(-w/2,-h/2)
298
+ return F.update(inp, img=img, homography=trf)
299
+
300
+
301
+ class RandomTilting (DatasetWithRng):
302
+ """Apply a random tilting (left, right, up, down) to the input PIL.Image
303
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
304
+
305
+ Args:
306
+ magnitude (float):
307
+ maximum magnitude of the random skew (value between 0 and 1)
308
+ directions (string):
309
+ tilting directions allowed (all, left, right, up, down)
310
+ examples: "all", "left,right", "up-down-right"
311
+ """
312
+
313
+ def __init__(self, magnitude, directions='all', **rng_seed):
314
+ super().__init__(**rng_seed)
315
+ self.magnitude = magnitude
316
+ self.directions = directions.lower().replace(',',' ').replace('-',' ')
317
+
318
+ def __repr__(self):
319
+ return "RandomTilt(%g, '%s')" % (self.magnitude,self.directions)
320
+
321
+ def __call__(self, inp):
322
+ img = F.grab(inp,'img')
323
+ w, h = img.size
324
+
325
+ x1,y1,x2,y2 = 0,0,h,w
326
+ original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]
327
+
328
+ max_skew_amount = max(w, h)
329
+ max_skew_amount = int(np.ceil(max_skew_amount * self.magnitude))
330
+ skew_amount = self.rng.integers(1, max_skew_amount)
331
+
332
+ if self.directions == 'all':
333
+ choices = [0,1,2,3]
334
+ else:
335
+ dirs = ['left', 'right', 'up', 'down']
336
+ choices = []
337
+ for d in self.directions.split():
338
+ try:
339
+ choices.append(dirs.index(d))
340
+ except:
341
+ raise ValueError('Tilting direction %s not recognized' % d)
342
+
343
+ skew_direction = self.rng.choice(choices)
344
+
345
+ # print('randomtitlting: ', skew_amount, skew_direction) # to debug random
346
+
347
+ if skew_direction == 0:
348
+ # Left Tilt
349
+ new_plane = [(y1, x1 - skew_amount), # Top Left
350
+ (y2, x1), # Top Right
351
+ (y2, x2), # Bottom Right
352
+ (y1, x2 + skew_amount)] # Bottom Left
353
+ elif skew_direction == 1:
354
+ # Right Tilt
355
+ new_plane = [(y1, x1), # Top Left
356
+ (y2, x1 - skew_amount), # Top Right
357
+ (y2, x2 + skew_amount), # Bottom Right
358
+ (y1, x2)] # Bottom Left
359
+ elif skew_direction == 2:
360
+ # Forward Tilt
361
+ new_plane = [(y1 - skew_amount, x1), # Top Left
362
+ (y2 + skew_amount, x1), # Top Right
363
+ (y2, x2), # Bottom Right
364
+ (y1, x2)] # Bottom Left
365
+ elif skew_direction == 3:
366
+ # Backward Tilt
367
+ new_plane = [(y1, x1), # Top Left
368
+ (y2, x1), # Top Right
369
+ (y2 + skew_amount, x2), # Bottom Right
370
+ (y1 - skew_amount, x2)] # Bottom Left
371
+
372
+ # To calculate the coefficients required by PIL for the perspective skew,
373
+ # see the following Stack Overflow discussion: https://goo.gl/sSgJdj
374
+ homography = F.homography_from_4pts(original_plane, new_plane)
375
+ img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC)
376
+
377
+ homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3))
378
+ return F.update(inp, img=img, homography=homography)
379
+
380
+
381
+ RandomHomography = RandomTilt = RandomTilting # redefinition
382
+
383
+
384
+ class Homography(object):
385
+ """Apply a known tilting to an image
386
+ """
387
+ def __init__(self, *homography):
388
+ assert len(homography) == 8
389
+ self.homography = homography
390
+
391
+ def __call__(self, inp):
392
+ img = F.grab(inp, 'img')
393
+ homography = self.homography
394
+
395
+ img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC)
396
+
397
+ homography = np.linalg.pinv(np.float32(list(homography)+[1]).reshape(3,3))
398
+ return F.update(inp, img=img, homography=homography)
399
+
400
+
401
+
402
+ class StillTransform (DatasetWithRng):
403
+ """ Takes and return an image, without changing its shape or geometry.
404
+ """
405
+ def _transform(self, img):
406
+ raise NotImplementedError()
407
+
408
+ def __call__(self, inp):
409
+ img = F.grab(inp,'img')
410
+
411
+ # transform the image (size should not change)
412
+ try:
413
+ img = self._transform(img)
414
+ except TypeError:
415
+ pass
416
+
417
+ return F.update(inp, img=img)
418
+
419
+
420
+
421
+ class PixelNoise (StillTransform):
422
+ """ Takes an image, and add random white noise.
423
+ """
424
+ def __init__(self, ampl=20, **rng_seed):
425
+ super().__init__(**rng_seed)
426
+ assert 0 <= ampl < 255
427
+ self.ampl = ampl
428
+
429
+ def __repr__(self):
430
+ return "PixelNoise(%g)" % self.ampl
431
+
432
+ def _transform(self, img):
433
+ img = np.float32(img)
434
+ img += self.rng.uniform(0.5-self.ampl/2, 0.5+self.ampl/2, size=img.shape)
435
+ return Image.fromarray(np.uint8(img.clip(0,255)))
436
+
437
+
438
+
439
+ class ColorJitter (StillTransform):
440
+ """Randomly change the brightness, contrast and saturation of an image.
441
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py
442
+
443
+ Args:
444
+ brightness (float): How much to jitter brightness. brightness_factor
445
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
446
+ contrast (float): How much to jitter contrast. contrast_factor
447
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
448
+ saturation (float): How much to jitter saturation. saturation_factor
449
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
450
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
451
+ [-hue, hue]. Should be >=0 and <= 0.5.
452
+ """
453
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **rng_seed):
454
+ super().__init__(**rng_seed)  # initialize self.rng, as the other transforms do
+ self.brightness = brightness
455
+ self.contrast = contrast
456
+ self.saturation = saturation
457
+ self.hue = hue
458
+
459
+ def __repr__(self):
460
+ return "ColorJitter(%g,%g,%g,%g)" % (
461
+ self.brightness, self.contrast, self.saturation, self.hue)
462
+
463
+ def get_params(self, brightness, contrast, saturation, hue):
464
+ """Get a randomized transform to be applied on image.
465
+ Arguments are same as that of __init__.
466
+ Returns:
467
+ Transform which randomly adjusts brightness, contrast and
468
+ saturation in a random order.
469
+ """
470
+ transforms = []
471
+ if brightness > 0:
472
+ brightness_factor = self.rng.uniform(max(0, 1 - brightness), 1 + brightness)
473
+ transforms.append(tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
474
+
475
+ if contrast > 0:
476
+ contrast_factor = self.rng.uniform(max(0, 1 - contrast), 1 + contrast)
477
+ transforms.append(tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
478
+
479
+ if saturation > 0:
480
+ saturation_factor = self.rng.uniform(max(0, 1 - saturation), 1 + saturation)
481
+ transforms.append(tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
482
+
483
+ if hue > 0:
484
+ hue_factor = self.rng.uniform(-hue, hue)
485
+ transforms.append(tvf.Lambda(lambda img: F.adjust_hue(img, hue_factor)))
486
+
487
+ # print('colorjitter: ', brightness_factor, contrast_factor, saturation_factor, hue_factor) # to debug random seed
488
+ self.rng.shuffle(transforms)
489
+ transform = tvf.Compose(transforms)
490
+ return transform
491
+
492
+ def _transform(self, img):
493
+ transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
494
+ return transform(img)
495
+
496
+
497
+ def pil_loader(path, mode='RGB'):
498
+ with warnings.catch_warnings():
499
+ warnings.simplefilter("ignore")
500
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
501
+ with (path if hasattr(path,'read') else open(path, 'rb')) as f:
502
+ img = Image.open(f)
503
+ return img.convert(mode)
504
+
505
+ def torchvision_loader(path, mode='RGB'):
506
+ from torchvision.io import read_file, decode_image, read_image, image
507
+ return read_image(getattr(path,'name',path), mode=getattr(image.ImageReadMode,mode))
508
+
509
+
510
+
511
+ if __name__ == '__main__':
512
+ from matplotlib import pyplot as pl
513
+ import argparse
514
+ from time import time as now  # used for timing in the loop below
515
+ parser = argparse.ArgumentParser("Script to try out and visualize transformations")
516
+ parser.add_argument('--img', type=str, default='imgs/test.png', help='input image')
517
+ parser.add_argument('--trfs', type=str, required=True, help='list of transformations')
518
+ parser.add_argument('--layout', type=int, nargs=2, default=(3,3), help='nb of rows,cols')
519
+ args = parser.parse_args()
520
+
521
+ img = dict(img=pil_loader(args.img))
522
+
523
+ trfs = instanciate_transforms(args.trfs)
524
+
525
+ pl.subplots_adjust(0,0,1,1)
526
+ nr,nc = args.layout
527
+
528
+ while True:
529
+ t0 = now()
530
+ imgs2 = [trfs(img) for _ in range(nr*nc)]
531
+
532
+ for j in range(nr):
533
+ for i in range(nc):
534
+ pl.subplot(nr,nc,i+j*nc+1)
535
+ img2 = img if i==j==0 else imgs2.pop() #trfs(img)
536
+ img2 = img2['img']
537
+ pl.imshow(img2)
538
+ pl.xlabel("%d x %d" % img2.size)
539
+ print(f'Took {now() - t0:.2f} seconds')
540
+ pl.show()
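
A minimal sketch of how the transform specification strings used throughout the datasets are interpreted (illustrative; the blank placeholder image and sizes are arbitrary): `instanciate_transforms` evaluates the comma-separated list into a `Compose`, and when the input is a dict the geometric transforms also accumulate the corresponding homography.

    import numpy as np
    from PIL import Image
    from datasets.transforms import instanciate_transforms

    trfs = instanciate_transforms("RandomRotation(10), RandomCrop(224)")

    inp = dict(img=Image.new('RGB', (320, 240)), homography=np.eye(3, dtype=np.float32))
    out = trfs(inp)
    print(out['img'].size)       # (224, 224)
    print(out['homography'])     # 3x3 matrix mapping input pixel coords to output coords
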
datasets/transforms_tools.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import numpy as np
7
+ from PIL import Image, ImageOps, ImageEnhance
8
+
9
+
10
+ def grab( data, *fields ):
11
+ ''' Called to extract fields from a dictionary
12
+ '''
13
+ if isinstance(data, dict):
14
+ res = []
15
+ for f in fields:
16
+ res.append( data[f] )
17
+ return res[0] if len(fields) == 1 else tuple(res)
18
+
19
+ else: # or it must be the img directly
20
+ assert fields == ('img',) and isinstance(data, (np.ndarray, Image.Image)), \
21
+ f"data should be an image, not {type(data)}!"
22
+ return data
23
+
24
+
25
+ def update( data, **fields):
26
+ ''' Called to update the img_and_label
27
+ '''
28
+ if isinstance( data, dict):
29
+ if 'homography' in fields and 'homography' in data:
30
+ data['homography'] = fields.pop('homography') @ data['homography']
31
+ data.update(fields)
32
+ if 'img' in fields:
33
+ data['imsize'] = data['img'].size
34
+ return data
35
+
36
+ else: # or it must be the img directly
37
+ return fields['img']
38
+
39
+
40
+ def rand_log_uniform(rng, a, b):
41
+ return np.exp(rng.uniform(np.log(a),np.log(b)))
42
+
43
+
44
+ def translate(tx, ty):
45
+ return np.float32(((1,0,tx),(0,1,ty,),(0,0,1)))
46
+
47
+ def rotate(angle):
48
+ return np.float32(((np.cos(angle),-np.sin(angle),0),(np.sin(angle),np.cos(angle),0),(0,0,1)))
49
+
50
+
51
+ def is_pil_image(img):
52
+ return isinstance(img, Image.Image)
53
+
54
+
55
+ def homography_from_4pts(pts_cur, pts_new):
56
+ "pts_cur and pts_new = 4x2 point array, in [(x,y),...] format"
57
+ matrix = []
58
+ for p1, p2 in zip(pts_new, pts_cur):
59
+ matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
60
+ matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
61
+ A = np.matrix(matrix, dtype=np.float64)
62
+ B = np.array(pts_cur).reshape(8)
63
+
64
+ homography = np.dot(np.linalg.pinv(A), B)
65
+ homography = tuple(np.array(homography).reshape(8))
66
+ #print(homography)
67
+ return homography
68
+
69
+
70
+
71
+
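
A quick numeric check of `homography_from_4pts` (illustrative, with made-up points): the 8 returned coefficients, with a 1 appended and reshaped to 3x3, map points of `pts_new` back onto `pts_cur` in homogeneous coordinates, which is the convention PIL's `Image.PERSPECTIVE` expects.

    import numpy as np
    from datasets.transforms_tools import homography_from_4pts

    pts_cur = [(0, 0), (100, 0), (100, 100), (0, 100)]    # a square...
    pts_new = [(-10, 0), (110, 0), (100, 100), (0, 100)]  # ...and a skewed version of it

    coefs = homography_from_4pts(pts_cur, pts_new)        # 8 PIL PERSPECTIVE coefficients
    H = np.float32(coefs + (1,)).reshape(3, 3)            # maps pts_new -> pts_cur

    x, y = pts_new[0]
    u = H @ np.float32([x, y, 1])
    print(u[:2] / u[2])                                   # ~ (0, 0), i.e. pts_cur[0]
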
datasets/utils.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ class DatasetWithRng:
11
+ """ Make sure that RNG is distributed properly when torch.dataloader() is used
12
+ """
13
+
14
+ def __init__(self, seed=None):
15
+ self.seed = seed
16
+ self.rng = np.random.default_rng(seed)
17
+ self._rng_children = set()
18
+
19
+ def with_same_rng(self, dataset=None):
20
+ if dataset is not None:
21
+ assert isinstance(dataset, DatasetWithRng) and hasattr(dataset, 'rng'), bb()
22
+ self._rng_children.add( dataset )
23
+
24
+ # update all registered children
25
+ for db in self._rng_children:
26
+ db.rng = self.rng
27
+ db.with_same_rng() # recursive call
28
+ return dataset
29
+
30
+ def init_worker(self, tid):
31
+ if self.seed is None:
32
+ self.rng = np.random.default_rng()
33
+ else:
34
+ self.rng = np.random.default_rng(self.seed + tid)
35
+
36
+
37
+ class WorkerWithRngInit:
38
+ " Dataset inherits from datasets.DatasetWithRng() and has an init_worker() function "
39
+ def __call__(self, tid):
40
+ torch.utils.data.get_worker_info().dataset.init_worker(tid)
41
+
42
+
43
+ def corres_from_homography(homography, W, H, grid=64):
44
+ s = max(1, min(W, H) // grid) # at least `grid` points in smallest dim
45
+ sx, sy = [slice(s//2, l, s) for l in (W, H)]
46
+ grid1 = np.mgrid[sy, sx][::-1].reshape(2,-1).T # (x1,y1) grid
47
+
48
+ grid2 = applyh(homography, grid1)
49
+ scale = np.sqrt(np.abs(np.linalg.det(jacobianh(homography, grid1).T)))
50
+
51
+ corres = np.c_[grid1, grid2, np.ones_like(scale), np.zeros_like(scale), scale]
52
+ return corres
53
+
54
+
55
+ def invh( H ):
56
+ return np.linalg.inv(H)
57
+
58
+
59
+ def applyh(H, p, ncol=2, norm=True):
60
+ """ Apply the homography to a list of 2d points in homogeneous coordinates.
61
+
62
+ H: Homography (...x3x3 matrix/tensor)
63
+ p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
64
+
65
+ Returns an array of projected 2d points.
66
+ """
67
+ if isinstance(H, np.ndarray):
68
+ p = np.asarray(p)
69
+ elif isinstance(H, torch.Tensor):
70
+ p = torch.as_tensor(p, dtype=H.dtype)
71
+
72
+ if p.shape[-1]+1 == H.shape[-1]:
73
+ H = H.swapaxes(-1,-2) # transpose H
74
+ p = p @ H[...,:-1,:] + H[...,-1:,:]
75
+ else:
76
+ p = H @ p.T
77
+ if p.ndim >= 2: p = p.swapaxes(-1,-2)
78
+
79
+ if norm:
80
+ p /= p[...,-1:]
81
+ return p[...,:ncol]
82
+
83
+
84
+ def jacobianh(H, p):
85
+ """ H is an homography that maps: f_H(x,y) --> (f_1, f_2)
86
+ So the Jacobian J_H evaluated at p=(x,y) is a 2x2 matrix
87
+ Output shape = (2, 2, N) = (f_, xy, N)
88
+
89
+ Example of derivative:
90
+ numx a*X + b*Y + c*Z
91
+ since x = ----- = ---------------
92
+ denom u*X + v*Y + w*Z
93
+
94
+ numx' * denom - denom' * numx a*denom - u*numx
95
+ dx/dX = ----------------------------- = ----------------
96
+ denom**2 denom**2
97
+ """
98
+ (a, b, c), (d, e, f), (u, v, w) = H
99
+ numx, numy, denom = applyh(H, p, ncol=3, norm=False).T
100
+
101
+ # column x column x
102
+ J = np.float32(((a*denom - u*numx, b*denom - v*numx), # row f_1
103
+ (d*denom - u*numy, e*denom - v*numy))) # row f_2
104
+ return J / np.where(denom, denom*denom, np.nan)
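
A small self-check of `jacobianh` against finite differences (illustrative values, not part of the commit): for each point, the 2x2 analytic Jacobian of the homography should match the numerical derivative of `applyh`.

    import numpy as np
    from datasets.utils import applyh, jacobianh

    H = np.float32([[1.0, 0.1, 5.0],
                    [0.05, 0.9, -3.0],
                    [1e-4, 2e-4, 1.0]])
    p = np.float32([[10.0, 20.0]])

    J = jacobianh(H, p)[..., 0]        # analytic 2x2 Jacobian (rows: f1, f2; cols: x, y)

    eps = 1e-3
    fd = np.stack([(applyh(H, p + eps * np.eye(2, dtype=np.float32)[j]) - applyh(H, p)) / eps
                   for j in range(2)], axis=-1)[0]
    print(J)
    print(fd)                          # should agree up to ~1e-3
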
datasets/web_images.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import os, os.path as osp
7
+
8
+ from tqdm import trange
9
+ from .image_set import ImageSet, verify_img
10
+
11
+
12
+ class RandomWebImages (ImageSet):
13
+ """ 1 million distractors from Oxford and Paris Revisited
14
+ see http://ptak.felk.cvut.cz/revisitop/revisitop1m/
15
+ """
16
+ def __init__(self, start=0, end=52, root="datasets/revisitop1m"):
17
+ bar = None
18
+ imgs = []
19
+ for i in range(start, end):
20
+ try:
21
+ # read cached list
22
+ img_list_path = osp.join(root, "image_list_%d.txt"%i)
23
+ cached_imgs = [e.strip() for e in open(img_list_path)]
24
+ assert cached_imgs, f"Cache '{img_list_path}' is empty!"
25
+ imgs += cached_imgs
26
+
27
+ except IOError:
28
+ if bar is None:
29
+ bar = trange(start, 4*end, desc='Caching')
30
+ bar.update(4*i)
31
+
32
+ # create it
33
+ imgs_i = []
34
+ for d in range(i*4,(i+1)*4): # 4096 folders in total, on average 256 each
35
+ key = hex(d)[2:].zfill(3)
36
+ folder = osp.join(root, key)
37
+ if not osp.isdir(folder): continue
38
+ imgs_i += [f for f in os.listdir(folder) if verify_img(osp.join(folder, f), exts='.jpg')]
39
+ bar.update(1)
40
+ assert imgs_i, f"No images found in {folder}/"
41
+ open(img_list_path,'w').write('\n'.join(imgs_i))
42
+ imgs += imgs_i
43
+
44
+ if bar: bar.update(bar.total - bar.n)
45
+ super().__init__(root, imgs)
46
+
47
+ def get_image_path(self, idx):
48
+ key = self.imgs[idx]
49
+ return osp.join(self.root, key[:3], key)
50
+
demo_warping.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import os, os.path as osp
7
+
8
+ from PIL import Image
9
+ import numpy as np
10
+ from tools.viz import pl, noticks
11
+
12
+ """ This script will warp (deform) img2 so that it fits img1
13
+
14
+ >> In case of memory failure (not enough GPU memory):
15
+ try adding '--resize 400 300' (or larger values if possible) to the _exec(...) command below.
16
+ """
17
+
18
+ def parse_args():
19
+ import argparse
20
+ parser = argparse.ArgumentParser('PUMP image-warping demo script')
21
+
22
+ parser.add_argument('--img1', default='datasets/demo_warp/mountains_src.jpg')
23
+ parser.add_argument('--img2', default='datasets/demo_warp/mountains_tgt.jpg')
24
+ parser.add_argument('--output', default='results/demo_warp')
25
+
26
+ parser.add_argument('--just-print', action='store_true', help='just print commands')
27
+ return parser.parse_args()
28
+
29
+
30
+ def main( args ):
31
+ run_pump(args) and run_demo_warp(args)
32
+
33
+
34
+ def run_pump(args):
35
+ output_path = osp.join(args.output, args.img1, args.img2+'.corres')
36
+ if osp.isfile(output_path): return True
37
+
38
+ return _exec(f'''python test_singlescale_recursive.py
39
+ --img1 {args.img1}
40
+ --img2 {args.img2}
41
+ --post-filter densify=True
42
+ --output {output_path}''')
43
+
44
+
45
+ def run_demo_warp(args):
46
+ corres_path = osp.join(args.output, args.img1, args.img2+'.corres')
47
+ corres = np.load(corres_path)['corres']
48
+
49
+ img1 = Image.open(args.img1).convert('RGB')
50
+ img2 = Image.open(args.img2).convert('RGB')
51
+
52
+ W, H = img1.size
53
+ warped_img2 = warp_img(np.asarray(img2), corres[:,2:4].reshape(H,W,2))
54
+
55
+ pl.figure('Warping demo')
56
+
57
+ noticks(pl.subplot(211))
58
+ pl.imshow( img2 )
59
+ pl.title('Source image')
60
+
61
+ noticks(pl.subplot(223))
62
+ pl.imshow( img1 )
63
+ pl.title('Target image')
64
+
65
+ noticks(pl.subplot(224))
66
+ pl.imshow( warped_img2 )
67
+ pl.title('Source image warped to match target')
68
+
69
+ pl.tight_layout()
70
+ pl.show(block=True)
71
+
72
+
73
+ def warp_img( img, absolute_flow ):
74
+ H1, W1, TWO = absolute_flow.shape
75
+ H2, W2, THREE = img.shape
76
+ assert TWO == 2 and THREE == 3
77
+
78
+ warp = absolute_flow.round().astype(int)
79
+ invalid = (warp[:,:,0]<0) | (warp[:,:,0]>=W2) | (warp[:,:,1]<0) | (warp[:,:,1]>=H2)
80
+
81
+ warp[:,:,0] = warp[:,:,0].clip(min=0, max=W2-1)
82
+ warp[:,:,1] = warp[:,:,1].clip(min=0, max=H2-1)
83
+ warp = warp[:,:,0] + W2*warp[:,:,1]
84
+
85
+ warped_img = np.asarray(img).reshape(-1,3)[warp].reshape(H1,W1,3)
86
+ return warped_img
87
+
88
+
89
+ def _exec(cmd):
90
+ # strip & remove \n
91
+ cmd = ' '.join(cmd.split())
92
+
93
+ if args.just_print:
94
+ print(cmd)
95
+ return False
96
+ else:
97
+ return os.WEXITSTATUS(os.system(cmd)) == 0
98
+
99
+
100
+ if __name__ == '__main__':
101
+ args = parse_args()
102
+ main( args )
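A short sketch of re-using warp_img() on its own once run_pump() has produced the densified correspondence file; the .corres path below is the one the script builds from its default arguments, and is only illustrative:

import numpy as np
from PIL import Image
from demo_warping import warp_img

img1 = Image.open('datasets/demo_warp/mountains_src.jpg').convert('RGB')   # reference frame (--img1)
img2 = Image.open('datasets/demo_warp/mountains_tgt.jpg').convert('RGB')   # image to be warped (--img2)
# path built by run_pump(): osp.join(args.output, args.img1, args.img2 + '.corres')
corres_path = 'results/demo_warp/datasets/demo_warp/mountains_src.jpg/datasets/demo_warp/mountains_tgt.jpg.corres'
corres = np.load(corres_path)['corres']

W, H = img1.size
warped = warp_img(np.asarray(img2), corres[:, 2:4].reshape(H, W, 2))        # (H, W, 3) uint8 array
Image.fromarray(warped).save('results/demo_warp/warped.jpg')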
download_training_data.sh ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ CODE_ROOT=`pwd`
6
+ if [ ! -e datasets ]; then
7
+ echo "Error: missing datasets/ folder"
8
+ echo "First, create a folder that can host (at least) 15 GB of data."
9
+ echo "Then, create a soft-link named 'data' that points to it."
10
+ exit -1
11
+ fi
12
+
13
+ # download some web images from the revisitop1m dataset
14
+ WEB_ROOT=datasets/revisitop1m
15
+ mkdir -p $WEB_ROOT
16
+ cd $WEB_ROOT
17
+ if [ ! -e 0d3 ]; then
18
+ for i in {1..5}; do
19
+ echo "Installing the web images dataset ($i/5)..."
20
+ if [ ! -f revisitop1m.$i.tar.gz ]; then
21
+ wget http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg/revisitop1m.$i.tar.gz
22
+ fi
23
+ tar -xzvf revisitop1m.$i.tar.gz
24
+ rm -f revisitop1m.$i.tar.gz
25
+ done
26
+ fi
27
+ cd $CODE_ROOT
28
+
29
+ # download SfM120k pairs
30
+ SFM_ROOT=datasets/sfm120k
31
+ mkdir -p $SFM_ROOT
32
+ cd $SFM_ROOT
33
+ if [ ! -e "ims" ]; then
34
+ echo "Downloading the SfM120k dataset..."
35
+ fname=ims.tar.gz
36
+ if [ ! -f $fname ]; then
37
+ wget http://cmp.felk.cvut.cz/cnnimageretrieval/data/train/ims/ims.tar.gz
38
+ fi
39
+ mkdir -p ims && tar -xzvf $fname -C ims
40
+ rm -f $fname
41
+ fi
42
+ if [ ! -e "corres" ]; then
43
+ echo "Installing the SfM120k dataset..."
44
+ fname=corres.tar.gz
45
+ if [ ! -f $fname ]; then
46
+ wget https://download.europe.naverlabs.com/corres.tar.gz
47
+ fi
48
+ tar -xzvf $fname
49
+ rm -f $fname
50
+ fi
51
+ cd $CODE_ROOT
52
+
53
+ echo "Done!"
imgs/demo_warp.jpg ADDED
imgs/overview.png ADDED
imgs/teaser_paper.jpg ADDED
imgs/test.png ADDED
post_filter.py ADDED
@@ -0,0 +1,235 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ import pdb, sys, os
6
+ import argparse
7
+ import numpy as np
8
+ from scipy.sparse import coo_matrix, csr_matrix, triu, csgraph
9
+
10
+ import core.functional as myF
11
+ from tools.common import image, image_with_trf
12
+ from tools.viz import dbgfig, show_correspondences
13
+
14
+
15
+ def arg_parser():
16
+ parser = argparse.ArgumentParser("Post-filtering of Deep matching correspondences")
17
+
18
+ parser.add_argument("--img1", required=True, help="path to first image")
19
+ parser.add_argument("--img2", required=True, help="path to second image")
20
+ parser.add_argument("--resize", default=0, type=int, help="prior image downsize (0 if recursive)")
21
+ parser.add_argument("--corres", required=True, help="input path")
22
+ parser.add_argument("--output", default="", help="filtered corres output")
23
+
24
+ parser.add_argument("--locality", type=float, default=2, help="tolerance to deformation")
25
+ parser.add_argument("--min-cc-size", type=int, default=50, help="min connex-component size")
26
+ parser.add_argument("--densify", default='no', choices=['no','full','cc','convex'], help="output pixel-dense corres field")
27
+ parser.add_argument("--dense-side", default='left', choices=['left','right'], help="img to densify")
28
+
29
+ parser.add_argument("--verbose", "-v", type=int, default=0, help="verbosity level")
30
+ parser.add_argument("--dbg", type=str, nargs='+', default=(), help="debug options")
31
+ return parser
32
+
33
+
34
+ def main(args):
35
+ import test_singlescale as pump
36
+ corres = np.load(args.corres)['corres']
37
+ imgs = tuple(map(image, pump.Main.load_images(args)))
38
+
39
+ if dbgfig('raw',args.dbg):
40
+ show_correspondences(*imgs, corres)
41
+
42
+ corres = filter_corres( *imgs, corres,
43
+ locality=args.locality, min_cc_size=args.min_cc_size,
44
+ densify=args.densify, dense_side=args.dense_side,
45
+ verbose=args.verbose, dbg=args.dbg)
46
+
47
+ if dbgfig('viz',args.dbg):
48
+ show_correspondences(*imgs, corres)
49
+
50
+ return pump.save_output( args, corres )
51
+
52
+
53
+ def filter_corres( img0, img1, corres,
54
+ locality = None, # graph edge locality
55
+ min_cc_size = None, # min CC size
56
+ densify = None,
57
+ dense_side = None,
58
+ verbose = 0, dbg=()):
59
+
60
+ if None in (locality, min_cc_size, densify, dense_side):
61
+ default_params = arg_parser()
62
+ locality = locality or default_params.get_default('locality')
63
+ min_cc_size = min_cc_size or default_params.get_default('min_cc_size')
64
+ densify = densify or default_params.get_default('densify')
65
+ dense_side = dense_side or default_params.get_default('dense_side')
66
+
67
+ img0, trf0 = img0 if isinstance(img0,tuple) else (img0, np.eye(3))
68
+ img1, trf1 = img1 if isinstance(img1,tuple) else (img1, np.eye(3))
69
+ assert isinstance(img0, np.ndarray) and isinstance(img1, np.ndarray)
70
+
71
+ corres = myF.affmul((np.linalg.inv(trf0),np.linalg.inv(trf1)), corres)
72
+ n_corres = len(corres)
73
+ if verbose: print(f'>> input: {len(corres)} correspondences')
74
+
75
+ graph = compute_graph(corres, max_dis=locality*4)
76
+ if verbose: print(f'>> {locality=}: {graph.nnz} edges in graph')
77
+
78
+ cc_sizes = measure_connected_components(graph)
79
+ corres[:,4] += np.log2(cc_sizes)
80
+ corres = corres[cc_sizes > min_cc_size]
81
+ if verbose: print(f'>> {min_cc_size=}: remaining {len(corres)} correspondences')
82
+
83
+ final = myF.affmul((trf0,trf1), corres)
84
+
85
+ if densify != 'no':
86
+ # densify correspondences
87
+ if dense_side == 'right': # temporary swap
88
+ final = final[:,[2,3,0,1]]
89
+ H = round(img1.shape[0] / trf1[1,1])
90
+ W = round(img1.shape[1] / trf1[0,0])
91
+ else:
92
+ H = round(img0.shape[0] / trf0[1,1])
93
+ W = round(img0.shape[1] / trf0[0,0])
94
+
95
+ if densify == 'cc':
96
+ assert False, 'todo'
97
+ elif densify in (True, 'full', 'convex'):
98
+ # recover true image0's shape
99
+ final = densify_corres( final, (H, W), full=(densify!='convex') )
100
+ else:
101
+ raise ValueError(f'Bad mode for {densify=}')
102
+
103
+ if dense_side == 'right': # undo temporary swap
104
+ final = final[:,[2,3,0,1]]
105
+
106
+ return final
107
+
108
+
109
+ def compute_graph(corres, max_dis=10, min_ang=90):
110
+ """ 4D distances (corres can only be connected to same scale)
111
+ using sparse matrices for efficiency
112
+
113
+ step1: build horizontal and vertical binning, binsize = max_dis
114
+ add in each bin all neighbor bins
115
+ step2: for each corres, we can intersect 2 bins to get a short list of candidates
116
+ step3: verify euclidean distance < maxdis (optional?)
117
+ """
118
+ def bin_positions(pos):
119
+ # every corres goes into a single bin
120
+ bin_indices = np.int32(pos.clip(min=0) // max_dis) + 1
121
+ cols = np.arange(len(pos))
122
+
123
+ # add the cell before and the cell after, to handle border effects
124
+ res = csr_matrix((np.ones(len(bin_indices)*3,dtype=np.float32),
125
+ (np.r_[bin_indices-1, bin_indices, bin_indices+1], np.r_[cols,cols,cols])),
126
+ shape=(bin_indices.max()+2 if bin_indices.size else 1, len(pos)))
127
+
128
+ return res, bin_indices
129
+
130
+ # 1-hot matrices of shape = nbins x n_corres
131
+ x1_bins = bin_positions(corres[:,0])
132
+ y1_bins = bin_positions(corres[:,1])
133
+ x2_bins = bin_positions(corres[:,2])
134
+ y2_bins = bin_positions(corres[:,3])
135
+
136
+ def row_indices(ngh):
137
+ res = np.bincount(ngh.indptr[1:-1], minlength=ngh.indptr[-1])[:-1]
138
+ return res.cumsum()
139
+
140
+ def compute_dist( ngh, pts, scale=None ):
141
+ # pos from the second point
142
+ x_pos = pts[ngh.indices,0]
143
+ y_pos = pts[ngh.indices,1]
144
+
145
+ # subtract pos from the 1st point
146
+ rows = row_indices(ngh)
147
+ x_pos -= pts[rows, 0]
148
+ y_pos -= pts[rows, 1]
149
+ dis = np.sqrt(np.square(x_pos) + np.square(y_pos))
150
+ if scale is not None:
151
+ # there is a scale for each of the 2 pts, we lean toward the worse one
152
+ dis *= (scale[rows] + scale[ngh.indices]) / 2 # so we use arithmetic instead of geometric mean
153
+
154
+ return normed(np.c_[x_pos, y_pos]), dis
155
+
156
+ def Rot( ngh, degrees ):
157
+ rows = row_indices(ngh)
158
+ rad = degrees * np.pi / 180
159
+ rad = (rad[rows] + rad[ngh.indices]) / 2 # average angle between 2 corres
160
+ cos, sin = np.cos(rad), np.sin(rad)
161
+ return np.float32(((cos, -sin), (sin,cos))).transpose(2,0,1)
162
+
163
+ def match(xbins, ybins, pt1, pt2, way):
164
+ xb, ixb = xbins
165
+ yb, iyb = ybins
166
+
167
+ # gets for each corres a list of potential matches
168
+ ngh = xb[ixb].multiply( yb[iyb] ) # shape = n_corres x n_corres
169
+ ngh = triu(ngh, k=1).tocsr() # remove mirrored matches
170
+ # ngh = matches of matches, shape = n_corres x n_corres
171
+
172
+ # verify locality and flow
173
+ vec1, d1 = compute_dist(ngh, pt1) # for each match, distance and orientation in img1
174
+ # assert d1.max()**0.5 < 2*max_dis*1.415, 'cannot be larger than 2 cells in diagonals, or there is a bug'+bb()
175
+ scale, rot = myF.decode_scale_rot(corres[:,5])
176
+ vec2, d2 = compute_dist(ngh, pt2, scale=scale**(-way))
177
+ ang = np.einsum('ik,ik->i', (vec1[:,None] @ Rot(ngh,way*rot))[:,0], vec2)
178
+
179
+ valid = (d1 <= max_dis) & (d2 <= max_dis) & (ang >= np.cos(min_ang*np.pi/180))
180
+ res = csr_matrix((valid, ngh.indices, ngh.indptr), shape=ngh.shape)
181
+ res.eliminate_zeros()
182
+ return res
183
+
184
+ # find all neighbors within each xy bin
185
+ ngh1 = match(x1_bins, y1_bins, corres[:,0:2], corres[:,2:4], way=+1)
186
+ ngh2 = match(x2_bins, y2_bins, corres[:,2:4], corres[:,0:2], way=-1).T
187
+
188
+ return ngh1 + ngh2 # union
189
+
190
+
191
+ def measure_connected_components(graph, dbg=()):
192
+ # compute connected components
193
+ nc, labels = csgraph.connected_components(graph, directed=False)
194
+
195
+ # count the size of each connected component
196
+ count = np.bincount(labels)
197
+
198
+ return count[labels]
199
+
200
+ def normed( mat ):
201
+ return mat / np.linalg.norm(mat, axis=-1, keepdims=True).clip(min=1e-16)
202
+
203
+
204
+ def densify_corres( corres, shape, full=True ):
205
+ from scipy.interpolate import LinearNDInterpolator
206
+ from scipy.spatial import cKDTree as KDTree
207
+
208
+ assert len(corres) > 3, 'Not enough corres for densification'
209
+ H, W = shape
210
+
211
+ interp = LinearNDInterpolator(corres[:,0:2], corres[:,2:4])
212
+ X, Y = np.mgrid[0:H, 0:W][::-1] # H x W, H x W
213
+ p1 = np.c_[X.ravel(), Y.ravel()]
214
+ p2 = interp(X, Y) # H x W x 2
215
+
216
+ p2 = p2.reshape(-1,2)
217
+ invalid = np.isnan(p2).any(axis=1)
218
+
219
+ if full:
220
+ # interpolate pixels outside of the convex hull
221
+ badp = p1[invalid]
222
+ tree = KDTree(corres[:,0:2])
223
+ _, nn = tree.query(badp, 3) # find 3 closest neighbors
224
+ corflow = corres[:,2:4] - corres[:,0:2]
225
+ p2.reshape(-1,2)[invalid] = corflow[nn].mean(axis=1) + p1[invalid]
226
+ else:
227
+ # remove nans, i.e. remove points outside of convex hull
228
+ p1, p2 = p1[~invalid], p2[~invalid]
229
+
230
+ # return correspondence field
231
+ return np.c_[p1, p2]
232
+
233
+
234
+ if __name__ == '__main__':
235
+ main(arg_parser().parse_args())
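For completeness, filter_corres() can also be called directly from Python; a minimal sketch, where the .corres path is hypothetical and the zero images only serve as placeholders for the coordinate frames:

import numpy as np
from post_filter import filter_corres

img0 = np.zeros((480, 640, 3), np.uint8)            # placeholder first image (H, W, 3)
img1 = np.zeros((480, 640, 3), np.uint8)            # placeholder second image
corres = np.load('results/pair.corres')['corres']   # hypothetical path, rows = (x1, y1, x2, y2, score, code)

filtered = filter_corres(img0, img1, corres, locality=2, min_cc_size=50,
                         densify='no', dense_side='left', verbose=1)
print(f'{len(corres)} -> {len(filtered)} correspondences')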
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ matplotlib
2
+ numpy
3
+ scipy
4
+ torch==1.11.0
5
+ torchvision==0.12.0
run_ETH3D.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import os, os.path as osp
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+ SEQUENCES = [ 'lakeside', 'sand_box', 'storage_room', 'storage_room_2', 'tunnel',
11
+ 'delivery_area', 'electro', 'forest', 'playground', 'terrains']
12
+
13
+ RATES = [3, 5, 7, 9, 11, 13, 15]
14
+
15
+ def parse_args():
16
+ import argparse
17
+ parser = argparse.ArgumentParser('PUMP evaluation script for the ETH3D dataset')
18
+
19
+ parser.add_argument('--root', default='datasets/eth3d')
20
+ parser.add_argument('--output', default='results/eth3d')
21
+
22
+ parser.add_argument('--just-print', action='store_true', help='just print commands')
23
+ return parser.parse_args()
24
+
25
+
26
+ def main( args ):
27
+ run_pump(args) and run_eval(args)
28
+
29
+
30
+ def run_pump(args):
31
+ done = True
32
+ for img1, img2 in tqdm(list_eth3d_pairs()):
33
+ output_path = osp.join(args.output, img1, img2+'.corres')
34
+ if osp.isfile(output_path): continue
35
+
36
+ done = False
37
+ _exec(f'''python test_multiscale_recursive.py
38
+ --img1 {osp.join(args.root,img1)}
39
+ --img2 {osp.join(args.root,img2)}
40
+ --max-scale 1.5
41
+ --desc PUMP
42
+ --post-filter "densify=True,dense_side='right'"
43
+ --output {output_path}''')
44
+
45
+ return done
46
+
47
+
48
+ def run_eval( args ):
49
+ for rate in RATES:
50
+ mean_aepe_per_rate = 0
51
+
52
+ for seq in SEQUENCES:
53
+ pairs = np.load(osp.join(args.root, 'info_ETH3D_files', f'{seq}_every_5_rate_of_{rate}'), allow_pickle=True)
54
+
55
+ mean_aepe_per_seq = 0
56
+ for pair in pairs:
57
+ img1, img2 = pair['source_image'], pair['target_image']
58
+ Ys, Xs, Yt, Xt = [np.float32(pair[k]) for k in 'Ys Xs Yt Xt'.split()]
59
+
60
+ corres_path = osp.join(args.output, img1, img2+'.corres')
61
+ corres = np.load(corres_path, allow_pickle=True)['corres']
62
+
63
+ # extract estimated and target flow
64
+ W, H = np.int32(corres[-1, 2:4] + 1)
65
+ flow = (corres[:,0:2] - corres[:,2:4]).reshape(H, W, 2)
66
+ iYt, iXt = np.int32(np.round(Yt)), np.int32(np.round(Xt))
67
+ if 'correct way':
68
+ gt_targets = np.c_[Xs - Xt, Ys - Yt]
69
+ est_targets = flow[iYt, iXt]
70
+ elif 'GLU-Net way (somewhat inaccurate because of overlapping points in the mask)':
71
+ mask = np.zeros((H,W), dtype=bool)
72
+ mask[iYt, iXt] = True
73
+ gt_flow = np.full((H,W,2), np.nan, dtype=np.float32)
74
+ gt_flow[iYt, iXt, 0] = Xs - Xt
75
+ gt_flow[iYt, iXt, 1] = Ys - Yt
76
+ gt_targets = gt_flow[mask]
77
+ est_targets = flow[mask]
78
+
79
+ # compute end-point error
80
+ aepe = np.linalg.norm(est_targets - gt_targets, axis=-1).mean()
81
+ mean_aepe_per_seq += aepe
82
+
83
+ mean_aepe_per_seq /= len(pairs)
84
+ mean_aepe_per_rate += mean_aepe_per_seq
85
+ print(f'mean AEPE for {rate=} {seq=}:', mean_aepe_per_seq)
86
+
87
+ print(f'>> mean AEPE for {rate=}:', mean_aepe_per_rate / len(SEQUENCES))
88
+
89
+
90
+ def list_eth3d_pairs():
91
+ path = osp.join(args.root, 'info_ETH3D_files', 'list_pairs.txt')
92
+ try:
93
+ lines = open(path).read().splitlines()
94
+ except OSError:
95
+ lines = []
96
+ for seq in SEQUENCES:
97
+ for rate in RATES:
98
+ pairs = np.load(osp.join(args.root, 'info_ETH3D_files', f'{seq}_every_5_rate_of_{rate}'), allow_pickle=True)
99
+ for pair in pairs:
100
+ lines.append(pair['source_image'] + ' ' + pair['target_image'])
101
+ open(path, 'w').write('\n'.join(lines))
102
+
103
+ pairs = [line.split() for line in lines if line[0] != '#']
104
+ return pairs
105
+
106
+
107
+ def _exec(cmd):
108
+ # strip & remove \n
109
+ cmd = ' '.join(cmd.split())
110
+ if args.just_print:
111
+ print(cmd)
112
+ else:
113
+ os.system(cmd)
114
+
115
+
116
+ if __name__ == '__main__':
117
+ args = parse_args()
118
+ main( args )
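As a sanity check of the metric above, the AEPE of the 'correct way' branch reduces to a mean Euclidean distance between estimated and ground-truth flow vectors; a toy example with made-up values:

import numpy as np

gt_targets  = np.float32([[0, 0], [0, 0]])   # ground-truth flow (Xs - Xt, Ys - Yt)
est_targets = np.float32([[1, 0], [0, 2]])   # estimated flow read from the corres field
aepe = np.linalg.norm(est_targets - gt_targets, axis=-1).mean()
print(aepe)   # (1.0 + 2.0) / 2 = 1.5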
test_multiscale.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ from itertools import starmap
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ import test_singlescale as tss
13
+ from core import functional as myF
14
+ from tools.common import todevice, cpu
15
+ from tools.viz import dbgfig, show_correspondences
16
+
17
+
18
+ def arg_parser():
19
+ parser = tss.arg_parser()
20
+ parser.set_defaults(levels = 0, verbose=0)
21
+
22
+ parser.add_argument('--min-scale', type=float, default=None, help='min scale ratio')
23
+ parser.add_argument('--max-scale', type=float, default=4, help='max scale ratio')
24
+
25
+ parser.add_argument('--min-rot', type=float, default=None, help='min rotation (in degrees) in [-180,180]')
26
+ parser.add_argument('--max-rot', type=float, default=0, help='max rotation (in degrees) in [0,180]')
27
+ parser.add_argument('--crop-rot', action='store_true', help='crop rotated image to prevent memory blow-up')
28
+ parser.add_argument('--rot-step', type=int, default=45, help='rotation step (in degrees)')
29
+
30
+ parser.add_argument('--no-swap', type=int, default=1, nargs='?', const=0, choices=[1,0,-1], help='if 0, img1 will have keypoints on a grid')
31
+ parser.add_argument('--same-levels', action='store_true', help='use the same number of pyramid levels for all scales')
32
+
33
+ parser.add_argument('--merge', choices='torch cpu cuda'.split(), default='cpu')
34
+ return parser
35
+
36
+
37
+ class MultiScalePUMP (nn.Module):
38
+ """ DeepMatching that loops over all possible {scale x rotation} combinations.
39
+ """
40
+ def __init__(self, matcher,
41
+ min_scale=1,
42
+ max_scale=1,
43
+ max_rot=0,
44
+ min_rot=0,
45
+ rot_step=45,
46
+ swap_mode=1,
47
+ same_levels=False,
48
+ crop_rot=False):
49
+ super().__init__()
50
+ min_scale = min_scale or 1/max_scale
51
+ min_rot = min_rot or -max_rot
52
+ assert 0.1 <= min_scale <= max_scale <= 10
53
+ assert -180 <= min_rot <= max_rot <= 180
54
+ self.matcher = matcher
55
+ self.matcher.crop_rot = crop_rot
56
+
57
+ self.min_sc = min_scale
58
+ self.max_sc = max_scale
59
+ self.min_rot = min_rot
60
+ self.max_rot = max_rot
61
+ self.rot_step = rot_step
62
+ self.swap_mode = swap_mode
63
+ self.merge_device = None
64
+ self.same_levels = same_levels
65
+
66
+ @torch.no_grad()
67
+ def forward(self, img1, img2, dbg=()):
68
+ img1, sca1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3, device=img1.device))
69
+ img2, sca2 = img2 if isinstance(img2, tuple) else (img2, torch.eye(3, device=img2.device))
70
+
71
+ # prepare correspondences accumulators
72
+ if self.same_levels: # limit number of levels
73
+ self.matcher.levels = self._find_max_levels(img1,img2)
74
+ elif self.matcher.levels == 0:
75
+ max_psize = int(min(np.mean(img1.shape[-2:]), np.mean(img2.shape[-2:])))
76
+ self.matcher.levels = int(np.log2(max_psize / self.matcher.pixel_desc.get_atomic_patch_size()))
77
+
78
+ all_corres = (self._make_accu(img1), self._make_accu(img2))
79
+
80
+ for scale, ang, code, swap, swapped, (scimg1, scimg2) in self._enum_scaled_pairs(img1, img2):
81
+ print(f"processing {scale=:g} x {ang=} {['','(swapped)'][swapped]} ({code=})...")
82
+
83
+ # compute correspondences with rotated+scaled image
84
+ corres, rots = self.process_one_scale(swapped, *[scimg1,scimg2], dbg=dbg)
85
+ if dbgfig('corres-ms', dbg): viz_correspondences(img1, img2, *corres, fig='last')
86
+
87
+ # merge correspondences in the reference frame
88
+ self.merge_corres( corres, rots, all_corres, code )
89
+
90
+ # final intersection
91
+ corres = self.reciprocal( *all_corres )
92
+ return myF.affmul(todevice((sca1,sca2),corres.device), corres) # rescaling to original image scale
93
+
94
+ def process_one_scale(self, swapped, *imgs, dbg=()):
95
+ return unswap(self.matcher(*imgs, ret='raw', dbg=dbg), swapped)
96
+
97
+ def _find_max_levels(self, img1, img2):
98
+ min_levels = self.matcher.levels or 999
99
+ for _, _, code, _, _, (img1, img2) in self._enum_scaled_pairs(img1, img2):
100
+ # first level when a parent doesn't have children: gap >= min(shape), with gap = 2**(level-2)
101
+ img1_levels = ceil(np.log2(min(img1[0].shape[-2:])) - 1)
102
+ # first level when img2's shape becomes smaller than self.min_shape, with shape = min(shape) / 2**level
103
+ img2_levels = ceil(np.log2(min(img2[0].shape[-2:]) / self.matcher.min_shape))
104
+ # print(f'predicted levels for {code=}:\timg1 --> {img1_levels},\timg2 --> {img2_levels} levels')
105
+ min_levels = min(min_levels, img1_levels, img2_levels)
106
+ return min_levels
107
+
108
+ def merge_corres(self, corres, rots, all_corres, code):
109
+ " rot : reference --> rotated "
110
+ self.merge_one_side( corres[0], slice(0,2), rots[0], all_corres[0], code )
111
+ self.merge_one_side( corres[1], slice(2,4), rots[1], all_corres[1], code )
112
+
113
+ def merge_one_side(self, corres, sel, trf, all_corres, code ):
114
+ pos, scores = corres
115
+ grid, accu = all_corres
116
+ accu = accu.view(-1, 6)
117
+
118
+ # compute 4-nn in transformed image for each grid point
119
+ best4 = torch.cdist(pos[:,sel].float(), grid).topk(4, dim=0, largest=False)
120
+ # best4.shape = (4, len(grid))
121
+
122
+ # update if score is better AND distance less than 2x best dist
123
+ scale = float(torch.sqrt(torch.det(trf))) # == scale (with scale >= 1)
124
+ dist_max = 8*scale - 1e-7 # 2x the distance between contiguous patches
125
+
126
+ close_enough = (best4.values <= 2*best4.values[0:1]) & (best4.values < dist_max)
127
+ neg_inf = torch.tensor(-np.inf, device=scores.device)
128
+ best_score = torch.where(close_enough, scores.ravel()[best4.indices], neg_inf).max(dim=0)
129
+ is_better = best_score.values > accu[:,4].ravel()
130
+
131
+ accu[is_better,0:4] = pos[best4.indices[best_score.indices,torch.arange(len(grid))][is_better]]
132
+ accu[is_better,4] = best_score.values[is_better]
133
+ accu[is_better,5] = code
134
+
135
+ def reciprocal(self, corres1, corres2 ):
136
+ grid1, corres1 = cpu(corres1)
137
+ grid2, corres2 = cpu(corres2)
138
+
139
+ (H1, W1), (H2, W2) = grid1[-1]+1, grid2[-1]+1
140
+ pos1 = corres1[:,:,0:4].view(-1,4)
141
+ pos2 = corres2[:,:,0:4].view(-1,4)
142
+
143
+ to_int = torch.tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32)
144
+ inter1 = myF.intersection(pos1@to_int, pos2@to_int)
145
+ return corres1.view(-1,6)[inter1]
146
+
147
+ def _enum_scales(self):
148
+ for i in range(-100,101):
149
+ scale = 2**(i/2)
150
+ # if i != -2: continue
151
+ if self.min_sc <= scale <= self.max_sc:
152
+ yield i,scale
153
+
154
+ def _enum_rotations(self):
155
+ for i in range(-180//self.rot_step, 180//self.rot_step):
156
+ rot = i * self.rot_step
157
+ if self.min_rot <= rot <= self.max_rot:
158
+ yield i,-rot
159
+
160
+ def _enum_scaled_pairs(self, img1, img2):
161
+ for s, scale in self._enum_scales():
162
+ (i1,sca1), (i2,sca2) = starmap(downsample_img, [(img1, min(scale, 1)), (img2, min(1/scale, 1))])
163
+ # set bigger image as the first one
164
+ size1 = min(i1.shape[-2:])
165
+ size2 = min(i2.shape[-2:])
166
+ swapped = size1*self.swap_mode < size2*self.swap_mode
167
+ swap = (1 - 2*swapped) # swapped ==> swap = -1
168
+ if swapped:
169
+ (i1,sca1), (i2,sca2) = (i2,sca2), (i1,sca1)
170
+
171
+ for r, ang in self._enum_rotations():
172
+ code = myF.encode_scale_rot(scale, ang)
173
+ trf1 = (sca1, swap*ang) if ang != 0 else sca1
174
+ yield scale, ang, code, swap, swapped, ((i1,trf1), (i2,sca2))
175
+
176
+ def _make_accu(self, img):
177
+ C, H, W = img.shape
178
+ step = self.matcher.pixel_desc.get_atomic_patch_size() // 2
179
+ h = step//2 - 1
180
+ accu = img.new_zeros(((H+h)//step, (W+h)//step, 6), dtype=torch.float32, device=self.merge_device or img.device)
181
+ grid = step * myF.mgrid(accu[:,:,0], device=img.device) + (step//2)
182
+ return grid, accu
183
+
184
+
185
+ def downsample_img(img, scale=0):
186
+ assert scale <= 1
187
+ img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
188
+ if scale == 1: return img, trf
189
+
190
+ assert img.dtype == torch.uint8
191
+ trf = trf.clone() # dont modify inplace
192
+ trf[:2,:2] /= scale
193
+ while scale <= 0.5:
194
+ img = F.avg_pool2d(img[None].float(), 2, stride=2, count_include_pad=False)[0]
195
+ scale *= 2
196
+ if scale != 1:
197
+ img = F.interpolate(img[None].float(), scale_factor=scale, mode='bicubic', align_corners=False, recompute_scale_factor=False).clamp(min=0, max=255)[0]
198
+ return img.byte(), trf # scaled --> pxl
199
+
200
+
201
+ def ceil(i):
202
+ return int(np.ceil(i))
203
+
204
+ def unswap( corres, swapped ):
205
+ swap = -1 if swapped else 1
206
+ corres, rots = corres
207
+ corres = corres[::swap]
208
+ rots = rots[::swap]
209
+ if swapped:
210
+ for pos, _ in corres:
211
+ pos[:,0:4] = pos[:,[2,3,0,1]].clone()
212
+ return corres, rots
213
+
214
+
215
+ def demultiplex_img_trf(self, img, force=False):
216
+ """ img is:
217
+ - an image
218
+ - a tuple (image, trf)
219
+ - a tuple (image, (cur_trf, trf_todo))
220
+ In any case, trf: cur_pix --> old_pix
221
+ """
222
+ img, trf = img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
223
+
224
+ if isinstance(trf, tuple):
225
+ trf, todo = trf
226
+ if isinstance(todo, (int,float)): # pure rotation
227
+ img, trf = myF.rotate_img((img,trf), angle=todo, crop=self.crop_rot)
228
+ else:
229
+ img = myF.apply_trf_to_img(todo, img)
230
+ trf = trf @ todo
231
+ return img, trf
232
+
233
+
234
+ class Main (tss.Main):
235
+ @staticmethod
236
+ def get_options( args ):
237
+ return dict(max_scale=args.max_scale, min_scale=args.min_scale,
238
+ max_rot=args.max_rot, min_rot=args.min_rot, rot_step=args.rot_step,
239
+ swap_mode=args.no_swap, same_levels=args.same_levels, crop_rot=args.crop_rot)
240
+
241
+ @staticmethod
242
+ def tune_matcher( args, matcher, device ):
243
+ if device == 'cpu':
244
+ args.merge = 'cpu'
245
+
246
+ if args.merge == 'cpu': type(matcher).merge_corres = myF.merge_corres; matcher.merge_device = 'cpu'
247
+ elif args.merge == 'cuda': type(matcher).merge_corres = myF.merge_corres
248
+
249
+ return matcher.to(device)
250
+
251
+ @staticmethod
252
+ def build_matcher( args, device):
253
+ # get a normal matcher
254
+ matcher = tss.Main.build_matcher(args, device)
255
+ type(matcher).demultiplex_img_trf = demultiplex_img_trf # update transformer
256
+
257
+ options = Main.get_options(args)
258
+ return Main.tune_matcher(args, MultiScalePUMP(matcher, **options), device)
259
+
260
+
261
+ if __name__ == '__main__':
262
+ Main().run_from_args(arg_parser().parse_args())
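To make the defaults above concrete, here is the scale grid that _enum_scales() walks with --max-scale 4 (--min-scale then defaults to 1/4); rotations stay at 0 unless --max-rot is raised. A small sketch mirroring that loop:

import numpy as np

min_sc, max_sc = 0.25, 4.0
scales = [2**(i/2) for i in range(-100, 101) if min_sc <= 2**(i/2) <= max_sc]
print(np.round(scales, 3))
# -> [0.25, 0.354, 0.5, 0.707, 1.0, 1.414, 2.0, 2.828, 4.0]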
test_multiscale_recursive.py ADDED
@@ -0,0 +1,24 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ import test_singlescale as ss
6
+ import test_singlescale_recursive as ssr
7
+ import test_multiscale as ms
8
+
9
+ def arg_parser():
10
+ parser = ssr.arg_parser(ms.arg_parser())
11
+ return parser
12
+
13
+ class Main (ms.Main):
14
+ @staticmethod
15
+ def build_matcher(args, device):
16
+ # get a single-scale recursive matcher
17
+ matcher = ssr.Main.build_matcher(args, device)
18
+ type(matcher).demultiplex_img_trf = ms.demultiplex_img_trf # update transformer
19
+
20
+ options = Main.get_options(args)
21
+ return Main.tune_matcher(args, ms.MultiScalePUMP(matcher, **options), device).to(device)
22
+
23
+ if __name__ == '__main__':
24
+ Main().run_from_args(arg_parser().parse_args())
test_singlescale.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from core import functional as myF
12
+ from core.pixel_desc import PixelDesc
13
+ from tools.common import mkdir_for, todevice, cudnn_benchmark, nparray, image, image_with_trf
14
+ from tools.viz import dbgfig, show_correspondences
15
+
16
+
17
+ def arg_parser():
18
+ import argparse
19
+ parser = argparse.ArgumentParser('SingleScalePUMP on GPU with PyTorch')
20
+
21
+ parser.add_argument('--img1', required=True, help='path to img1')
22
+ parser.add_argument('--img2', required=True, help='path to img2')
23
+ parser.add_argument('--resize', type=int, default=512, nargs='+', help='prior downsize of img1 and img2')
24
+
25
+ parser.add_argument('--output', default=None, help='output path for correspondences')
26
+
27
+ parser.add_argument('--levels', type=int, default=99, help='number of pyramid levels')
28
+ parser.add_argument('--min-shape', type=int, default=5, help='minimum size of corr maps')
29
+ parser.add_argument('--nlpow', type=float, default=1.5, help='non-linear activation power in [1,2]')
30
+ parser.add_argument('--border', type=float, default=0.9, help='border invariance level in [0,1]')
31
+ parser.add_argument('--dtype', default='float16', choices='float16 float32 float64'.split())
32
+
33
+ parser.add_argument('--desc', default='PUMP-stytrf', help='checkpoint name')
34
+ parser.add_argument('--first-level', choices='torch'.split(), default='torch')
35
+ parser.add_argument('--activation', choices='torch'.split(), default='torch')
36
+ parser.add_argument('--forward', choices='torch cuda cuda-lowmem'.split(), default='cuda-lowmem')
37
+ parser.add_argument('--backward', choices='python torch cuda'.split(), default='cuda')
38
+ parser.add_argument('--reciprocal', choices='cpu cuda'.split(), default='cpu')
39
+
40
+ parser.add_argument('--post-filter', default=None, const=True, nargs='?', help='post-filtering (See post_filter.py)')
41
+
42
+ parser.add_argument('--verbose', type=int, default=0, help='verbosity')
43
+ parser.add_argument('--device', default='cuda', help='gpu device')
44
+ parser.add_argument('--dbg', nargs='*', default=(), help='debug options')
45
+
46
+ return parser
47
+
48
+
49
+ class SingleScalePUMP (nn.Module):
50
+ def __init__(self, levels = 9, nlpow = 1.4, cutoff = 1,
51
+ border_inv=0.9, min_shape=5, renorm=(),
52
+ pixel_desc = None, dtype = torch.float32,
53
+ verbose = True ):
54
+ super().__init__()
55
+ self.levels = levels
56
+ self.min_shape = min_shape
57
+ self.nlpow = nlpow
58
+ self.border_inv = border_inv
59
+ assert pixel_desc, 'Requires a pixel descriptor'
60
+ self.pixel_desc = pixel_desc.configure(self)
61
+ self.dtype = dtype
62
+ self.verbose = verbose
63
+
64
+ @torch.no_grad()
65
+ def forward(self, img1, img2, ret='corres', dbg=()):
66
+ with cudnn_benchmark(False):
67
+ # compute descriptors
68
+ (img1, img2), pixel_descs, trfs = self.extract_descs(img1, img2, dtype=self.dtype)
69
+
70
+ # backward and forward passes
71
+ pixel_corr = self.first_level(*pixel_descs, dbg=dbg)
72
+ pixel_corr = self.backward_pass(self.forward_pass(pixel_corr, dbg=dbg), dbg=dbg)
73
+
74
+ # recover correspondences
75
+ corres = myF.best_correspondences( pixel_corr )
76
+
77
+ if dbgfig('corres', dbg): viz_correspondences(img1[0], img2[0], *corres, fig='last')
78
+ corres = [(myF.affmul(trfs,pos),score) for pos, score in corres] # rectify scaling etc.
79
+ if ret == 'raw': return corres, trfs
80
+ return self.reciprocal(*corres)
81
+
82
+ def extract_descs(self, img1, img2, dtype=None):
83
+ img1, sca1 = self.demultiplex_img_trf(img1)
84
+ img2, sca2 = self.demultiplex_img_trf(img2)
85
+ desc1, trf1 = self.pixel_desc(img1)
86
+ desc2, trf2 = self.pixel_desc(img2)
87
+ return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1@trf1, sca2@trf2)
88
+
89
+ def demultiplex_img_trf(self, img, **kw):
90
+ return img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))
91
+
92
+ def forward_pass(self, pixel_corr, dbg=()):
93
+ weights = None
94
+ if isinstance(pixel_corr, tuple):
95
+ pixel_corr, weights = pixel_corr
96
+
97
+ # first-level with activation
98
+ if self.verbose: print(f' Pyramid level {0} shape={tuple(pixel_corr.shape)}')
99
+ pyramid = [ self.activation(0,pixel_corr) ]
100
+ if dbgfig(f'corr0', dbg): viz_correlation_maps(*from_stack('img1','img2'), pyramid[0], fig='last')
101
+
102
+ for level in range(1, self.levels+1):
103
+ upper, weights = self.forward_level(level, pyramid[-1], weights)
104
+ if weights.sum() == 0: break # img1 has become too small
105
+
106
+ # activation
107
+ pyramid.append( self.activation(level,upper) )
108
+
109
+ if self.verbose: print(f' Pyramid level {level} shape={tuple(upper.shape)}')
110
+ if dbgfig(f'corr{level}', dbg): viz_correlation_maps(*from_stack('img1','img2'), upper, level=level, fig='last')
111
+ if min(upper.shape[-2:]) <= self.min_shape: break # img2 has become too small
112
+
113
+ return pyramid
114
+
115
+ def forward_level(self, level, corr, weights):
116
+ # max-pooling
117
+ pooled = F.max_pool2d(corr, 3, padding=1, stride=2)
118
+
119
+ # sparse conv
120
+ return myF.sparse_conv(level, pooled, weights, norm=self.border_inv)
121
+
122
+ def backward_pass(self, pyramid, dbg=()):
123
+ # same as the forward pass, in reverse order
124
+ for level in range(len(pyramid)-1, 0, -1):
125
+ lower = self.backward_level(level, pyramid)
126
+ # assert not torch.isnan(lower).any(), bb()
127
+ if self.verbose: print(f' Pyramid level {level-1} shape={tuple(lower.shape)}')
128
+ del pyramid[-1] # free memory
129
+ if dbgfig(f'corr{level}-bw', dbg): viz_correlation_maps(img1, img2, lower, fig='last')
130
+ return pyramid[0]
131
+
132
+ def backward_level(self, level, pyramid):
133
+ # reverse sparse-conv
134
+ pooled = myF.sparse_conv(level, pyramid[level], reverse=True)
135
+
136
+ # reverse max-pool and add to lower level
137
+ return myF.max_unpool(pooled, pyramid[level-1])
138
+
139
+ def activation(self, level, corr):
140
+ assert 1 <= self.nlpow <= 3
141
+ corr.clamp_(min=0).pow_(self.nlpow)
142
+ return corr
143
+
144
+ def first_level(self, desc1, desc2, dbg=()):
145
+ assert desc1.ndim == desc2.ndim == 4
146
+ assert len(desc1) == len(desc2) == 1, "not implemented"
147
+ H1, W1 = desc1.shape[-2:]
148
+ H2, W2 = desc2.shape[-2:]
149
+
150
+ patches = F.unfold(desc1, 4, stride=4) # C*4*4, H1*W1//16
151
+ B, C, N = patches.shape
152
+ # rearrange(patches, 'B (C Kh Kw) H1W1 -> B H1W1 C Kh Kw', Kh=4, Kw=4)
153
+ patches = patches.permute(0, 2, 1).view(B, N, C//16, 4, 4)
154
+
155
+ corr, norms = myF.normalized_corr(patches[0], desc2[0], ret_norms=True)
156
+ if dbgfig('ncc',dbg):
157
+ for j in range(0,len(corr),9):
158
+ for i in range(9):
159
+ pl.subplot(3,3,i+1).cla()
160
+ i += j
161
+ pl.imshow(corr[i], vmin=0.9, vmax=1)
162
+ pl.plot(2+(i%16)*4, 2+(i//16)*4,'xr', ms=10)
163
+ bb()
164
+ return corr.view(H1//4, W1//4, H2+1, W2+1), (norms.view(H1//4, W1//4)>0).float()
165
+
166
+ def reciprocal(self, corres1, corres2 ):
167
+ corres1, corres2 = todevice(corres1, 'cpu'), todevice(corres2, 'cpu')
168
+ return myF.reciprocal(self, corres1, corres2)
169
+
170
+
171
+ class Main:
172
+ def __init__(self):
173
+ self.post_filtering = False
174
+
175
+ def run_from_args(self, args):
176
+ device = args.device
177
+ self.matcher = self.build_matcher(args, device)
178
+ if args.post_filter:
179
+ self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')
180
+
181
+ corres = self(*self.load_images(args, device), dbg=set(args.dbg))
182
+
183
+ if args.output:
184
+ self.save_output( args.output, corres )
185
+
186
+ def run_from_args_with_images(self, img1, img2, args):
187
+ device = args.device
188
+ self.matcher = self.build_matcher(args, device)
189
+ if args.post_filter:
190
+ self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')
191
+
192
+ if isinstance(args.resize, int): # user can provide 2 separate sizes for each image
193
+ args.resize = (args.resize, args.resize)
194
+
195
+ if len(args.resize) == 1:
196
+ args.resize = 2 * args.resize
197
+
198
+ images = []
199
+ for imgx, size in zip([img1, img2], args.resize):
200
+ img = torch.from_numpy(np.array(imgx.convert('RGB'))).permute(2,0,1).to(device)
201
+ img = myF.imresize(img, size)
202
+ images.append( img )
203
+
204
+ corres = self(*images, dbg=set(args.dbg))
205
+
206
+ if args.output:
207
+ self.save_output( args.output, corres )
208
+
209
+ return corres
210
+
211
+
212
+ @staticmethod
213
+ def get_options( args ):
214
+ # configure the pipeline
215
+ pixel_desc = PixelDesc(path=f'checkpoints/{args.desc}.pt')
216
+ return dict(levels=args.levels, min_shape=args.min_shape, border_inv=args.border, nlpow=args.nlpow,
217
+ pixel_desc=pixel_desc, dtype=eval(f'torch.{args.dtype}'), verbose=args.verbose)
218
+
219
+ @staticmethod
220
+ def tune_matcher( args, matcher, device ):
221
+ if device == 'cpu':
222
+ matcher.dtype = torch.float32
223
+ args.forward = 'torch'
224
+ args.backward = 'torch'
225
+ args.reciprocal = 'cpu'
226
+
227
+ if args.forward == 'cuda': type(matcher).forward_level = myF.forward_cuda
228
+ if args.forward == 'cuda-lowmem':type(matcher).forward_level = myF.forward_cuda_lowmem
229
+ if args.backward == 'python': type(matcher).backward_pass = legacy.backward_python
230
+ if args.backward == 'cuda': type(matcher).backward_level = myF.backward_cuda
231
+ if args.reciprocal == 'cuda': type(matcher).reciprocal = myF.reciprocal
232
+
233
+ return matcher.to(device)
234
+
235
+ @staticmethod
236
+ def build_matcher(args, device):
237
+ options = Main.get_options(args)
238
+ matcher = SingleScalePUMP(**options)
239
+ return Main.tune_matcher(args, matcher, device)
240
+
241
+ def __call__(self, *imgs, dbg=()):
242
+ corres = self.matcher( *imgs, dbg=dbg).cpu().numpy()
243
+ if self.post_filtering is not False:
244
+ corres = self.post_filter( imgs, corres )
245
+
246
+ if 'print' in dbg: print(corres)
247
+ if dbgfig('viz',dbg): show_correspondences(*imgs, corres)
248
+ return corres
249
+
250
+ @staticmethod
251
+ def load_images( args, device='cpu' ):
252
+ def read_image(impath):
253
+ try:
254
+ from torchvision.io.image import read_image, ImageReadMode
255
+ return read_image(impath, mode=ImageReadMode.RGB)
256
+ except RuntimeError:
257
+ from PIL import Image
258
+ return torch.from_numpy(np.array(Image.open(impath).convert('RGB'))).permute(2,0,1)
259
+
260
+ if isinstance(args.resize, int): # user can provide 2 separate sizes for each image
261
+ args.resize = (args.resize, args.resize)
262
+
263
+ if len(args.resize) == 1:
264
+ args.resize = 2 * args.resize
265
+
266
+ images = []
267
+ for impath, size in zip([args.img1, args.img2], args.resize):
268
+ img = read_image(impath).to(device)
269
+ img = myF.imresize(img, size)
270
+ images.append( img )
271
+ return images
272
+
273
+ def post_filter(self, imgs, corres ):
274
+ from post_filter import filter_corres
275
+ return filter_corres(*map(image_with_trf,imgs), corres, **self.post_filtering)
276
+
277
+ def save_output(self, output_path, corres ):
278
+ mkdir_for( output_path )
279
+ np.savez(open(output_path,'wb'), corres=corres)
280
+
281
+
282
+
283
+ if __name__ == '__main__':
284
+ Main().run_from_args(arg_parser().parse_args())
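A short sketch of reading back the file written by Main.save_output(); the path is whatever was passed to --output (hypothetical below), and the column layout follows the correspondence arrays used above:

import numpy as np

corres = np.load('results/pair.corres')['corres']   # hypothetical --output path
# each row is typically (x1, y1, x2, y2, score, scale-rot code)
x1, y1, x2, y2 = corres[:, :4].T
print(corres.shape, 'score range:', corres[:, 4].min(), corres[:, 4].max())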
test_singlescale_recursive.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import torch
9
+
10
+ import test_singlescale as tss
11
+ import core.functional as myF
12
+ from tools.viz import dbgfig, show_correspondences
13
+
14
+
15
+ def arg_parser(parser = None):
16
+ parser = parser or tss.arg_parser()
17
+
18
+ parser.add_argument('--rec-overlap', type=float, default=0.5, help='overlap between tiles in [0,0.5]')
19
+ parser.add_argument('--rec-score-thr', type=float, default=1, help='corres score threshold to guide fine levels')
20
+ parser.add_argument('--rec-fast-thr', type=float, default=0.1, help='prune block if less than `fast` corres fall in it')
21
+
22
+ return parser
23
+
24
+
25
+ class RecursivePUMP (tss.SingleScalePUMP):
26
+ """ Recursive PUMP:
27
+ 1) find initial correspondences at a coarse scale,
28
+ 2) refine them at a selection of finer scales
29
+ """
30
+ def __init__(self, coarse_size=512, fine_size=512, rec_overlap=0.5, rec_score_thr=1.0,
31
+ rec_fast_thr = 0.1, **other_options ):
32
+ super().__init__(**other_options)
33
+ assert 10 < coarse_size < 1024
34
+ assert 10 < fine_size < 1024
35
+ assert 0 <= rec_overlap < 1
36
+ assert 0 < rec_fast_thr < 1
37
+ self.coarse_size = coarse_size
38
+ self.fine_size = fine_size
39
+ self.overlap = rec_overlap
40
+ self.score_thr = rec_score_thr
41
+ self.fast_thr = rec_fast_thr
42
+
43
+ @torch.no_grad()
44
+ def forward(self, img1, img2, ret='corres', dbg=()):
45
+ img1, sca1 = self.demultiplex_img_trf(img1, force=True)
46
+ img2, sca2 = self.demultiplex_img_trf(img2, force=True)
47
+ input_trfs = (sca1, sca2)
48
+
49
+ # coarse first level with low-res images
50
+ corres = self.coarse_correspondences(img1, img2)
51
+
52
+ # fine level: iterate on HQ blocks
53
+ accu1, accu2 = (self._make_accu(img1), self._make_accu(img2))
54
+ for block1, block2 in tqdm(list(self._enumerate_blocks(img1, img2, corres))):
55
+ # print(f"img1[{block1[}:{}, {}:{}]"
56
+ accus, trfs = tss.SingleScalePUMP.forward(self, block1, block2, ret='raw', dbg=dbg)
57
+ self._update_accu( accu1, accus[0], trfs[0][:2,2] )
58
+ self._update_accu( accu2, accus[1], trfs[1][:2,2] )
59
+
60
+ demul = lambda accu: (accu[:,:,:4].reshape(-1,4).clone(), accu[:,:,4].clone())
61
+ corres = demul(accu1), demul(accu2)
62
+ if dbgfig('corres', dbg): viz_correspondences(img1, img2, *corres, fig='last')
63
+ corres = [(myF.affmul(input_trfs,pos),score) for pos, score in corres] # rectify scaling etc.
64
+ if ret == 'raw': return corres, input_trfs
65
+ return self.reciprocal(*corres)
66
+
67
+ def coarse_correspondences(self, img1, img2, **kw):
68
+ # joint image resize, because relative size is important (multiscale)
69
+ shape1, shape2 = img1.shape[-2:], img2.shape[-2:]
70
+ if max(shape1 + shape2) > self.coarse_size:
71
+ f1 = self.coarse_size / max(shape1)
72
+ f2 = self.coarse_size / max(shape2)
73
+ f = min(f1, f2)
74
+ img1 = myF.imresize( img1, int(0.5+f*max(shape1)) )
75
+ img2 = myF.imresize( img2, int(0.5+f*max(shape2)) )
76
+ else:
77
+ f = 1
78
+
79
+ init_corres = tss.SingleScalePUMP.forward(self, img1, img2, **kw)
80
+ # show_correspondences(img1, img2, init_corres, fig='last')
81
+ corres = init_corres[init_corres[:,4] > self.score_thr]
82
+ print(f" keeping {len(corres)}/{len(init_corres)} corres with score > {self.score_thr} ...")
83
+ return corres
84
+
85
+ def _update_accu(self, accu, update, offset ):
86
+ pos, scores = update
87
+ H, W = scores.shape
88
+ offx, offy = map(lambda i: int(i/4), offset)
89
+ accu = accu[offy:offy+H, offx:offx+W]
90
+ better = accu[:,:,4] < scores
91
+ accu[:,:,4][better] = scores[better].float()
92
+ accu[:,:,0:4][better] = pos.reshape(H,W,4)[better]
93
+
94
+ def _enumerate_blocks(self, img1, img2, corres):
95
+ H1, W1, H2, W2 = img1.shape[1:] + img2.shape[1:]
96
+ size, step = self.fine_size, int(self.overlap * self.fine_size)
97
+ def regular_steps(size):
98
+ if size <= self.fine_size: return [0]
99
+ nb = int(np.ceil(size / step)) - 1 # guaranteed >= 1
100
+ return (np.linspace(0, size-self.fine_size, nb) / 4 + 0.5).astype(int) * 4
101
+ def translation(x,y):
102
+ res = torch.eye(3, device=img1.device)
103
+ res[0,2] = x
104
+ res[1,2] = y
105
+ return res
106
+ def block2(x2,y2):
107
+ return img2[:,y2:y2+size,x2:x2+size], translation(x2,y2)
108
+ cx1, cy1 = corres[:,0:2].T
109
+
110
+ for y1 in regular_steps(H1):
111
+ for x1 in regular_steps(W1):
112
+ block1 = (img1[:,y1:y1+size,x1:x1+size], translation(x1,y1))
113
+ c2 = corres[(y1<=cy1) & (cy1<y1+size) & (x1<=cx1) & (cx1<x1+size)]
114
+ nb_init = len(c2)
115
+ while len(c2):
116
+ cx2, cy2 = c2[:,2:4].T
117
+ x2, y2 = (int(max(0,min(W2-size,cx2.median()-size//2)) / 4 + 0.5) * 4,
118
+ int(max(0,min(H2-size,cy2.median()-size//2)) / 4 + 0.5) * 4)
119
+ inside = (y2<=cy2) & (cy2<y2+size) & (x2<=cx2) & (cx2<x2+size)
120
+ if not inside.any():
121
+ x2, y2 = c2[np.random.choice(len(c2)),2:4]
122
+ x2 = int(max(0,min(W2-size,x2-size//2)) / 4 + 0.5) * 4
123
+ y2 = int(max(0,min(H2-size,y2-size//2)) / 4 + 0.5) * 4
124
+ inside = (y2<=cy2) & (cy2<y2+size) & (x2<=cx2) & (cx2<x2+size)
125
+
126
+ if inside.sum()/nb_init >= self.fast_thr:
127
+ yield block1, block2(x2,y2)
128
+
129
+ c2 = c2[~inside] # remove
130
+
131
+ def _make_accu(self, img):
132
+ C, H, W = img.shape
133
+ return img.new_zeros(((H+3)//4, (W+3)//4, 5), dtype=torch.float32)
134
+
135
+
136
+
137
+ class Main (tss.Main):
138
+ @staticmethod
139
+ def build_matcher(args, device):
140
+ # set coarse and fine size based on now obsolete --resize argument
141
+ if isinstance(args.resize, int): args.resize = [args.resize]
142
+ if len(args.resize) == 1: args.resize *= 2
143
+ args.rec_coarse_size, args.rec_fine_size = args.resize
144
+ args.resize = 0 # disable it so that image loading does not downsize images
145
+
146
+ options = Main.get_options( args )
147
+
148
+ matcher = RecursivePUMP( coarse_size=args.rec_coarse_size, fine_size=args.rec_fine_size,
149
+ rec_overlap=args.rec_overlap, rec_score_thr=args.rec_score_thr, rec_fast_thr=args.rec_fast_thr,
150
+ **options)
151
+
152
+ return tss.Main.tune_matcher(args, matcher, device)
153
+
154
+
155
+ if __name__ == '__main__':
156
+ Main().run_from_args(arg_parser().parse_args())
tools/common.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ import os
6
+ import torch
7
+ import numpy as np
8
+
9
+
10
+ def mkdir_for(file_path):
11
+ dirname = os.path.split(file_path)[0]
12
+ if dirname: os.makedirs(dirname, exist_ok=True)
13
+ return file_path
14
+
15
+
16
+ def model_size(model):
17
+ ''' Computes the number of parameters of the model
18
+ '''
19
+ size = 0
20
+ for weights in model.state_dict().values():
21
+ size += np.prod(weights.shape)
22
+ return size
23
+
24
+
25
+ class cudnn_benchmark:
26
+ " context manager to temporarily disable cudnn benchmark "
27
+ def __init__(self, activate ):
28
+ self.activate = activate
29
+ def __enter__(self):
30
+ self.old_bm = torch.backends.cudnn.benchmark
31
+ torch.backends.cudnn.benchmark = self.activate
32
+ def __exit__(self, *args):
33
+ torch.backends.cudnn.benchmark = self.old_bm
34
+
35
+
36
+ def todevice(x, device, non_blocking=False):
37
+ """ Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
38
+ x: array, tensor, or container of such.
39
+ device: pytorch device or 'numpy'
40
+ """
41
+ if isinstance(x, dict):
42
+ return {k:todevice(v, device) for k,v in x.items()}
43
+
44
+ if isinstance(x, (tuple,list)):
45
+ return type(x)(todevice(e, device) for e in x)
46
+
47
+ if device == 'numpy':
48
+ if isinstance(x, torch.Tensor):
49
+ x = x.detach().cpu().numpy()
50
+ elif x is not None:
51
+ if isinstance(x, np.ndarray):
52
+ x = torch.from_numpy(x)
53
+ x = x.to(device, non_blocking=non_blocking)
54
+ return x
55
+
56
+ def nparray( x ): return todevice(x, 'numpy')
57
+ def cpu( x ): return todevice(x, 'cpu')
58
+ def cuda( x ): return todevice(x, 'cuda')
59
+
60
+
61
+ def image( img, with_trf=False ):
62
+ " convert a torch.Tensor to a numpy image (H, W, 3) "
63
+ def convert_image(img):
64
+ if isinstance(img, torch.Tensor):
65
+ if img.dtype is not torch.uint8:
66
+ img = img * 255
67
+ if img.min() < -10:
68
+ img = img.clone()
69
+ for i, (mean, std) in enumerate(zip([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])):
70
+ img[i] *= std
71
+ img[i] += 255*mean
72
+ img = img.byte()
73
+ if img.shape[0] <= 3:
74
+ img = img.permute(1,2,0)
75
+ return img
76
+
77
+ if isinstance(img, tuple):
78
+ if with_trf:
79
+ return nparray(convert_image(img[0])), nparray(img[1])
80
+ else:
81
+ img = img[0]
82
+ return nparray(convert_image(img))
83
+
84
+
85
+ def image_with_trf( img ):
86
+ return image(img, with_trf=True)
87
+
88
+ class ToTensor:
89
+ " numpy images to float tensors "
90
+ def __call__(self, x):
91
+ assert x.ndim == 4 and x.shape[3] == 3
92
+ if isinstance(x, np.ndarray):
93
+ x = torch.from_numpy(x)
94
+ assert x.dtype == torch.uint8
95
+ return x.permute(0, 3, 1, 2).float() / 255
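A small sketch of how todevice() and its shortcuts walk nested containers while preserving their structure (the tensors are placeholders):

import torch
from tools.common import todevice, nparray, cpu

batch = {'img': torch.zeros(1, 3, 8, 8), 'meta': [torch.ones(2), torch.arange(3)]}
batch_np  = nparray(batch)              # same structure, tensors converted to numpy arrays
batch_cpu = cpu(batch)                  # same structure, tensors moved to CPU
# batch_gpu = todevice(batch, 'cuda')   # same structure, tensors moved to the GPU (if available)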
tools/trainer.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ import pdb; bb = pdb.set_trace
6
+ from tqdm import tqdm
7
+ from collections import defaultdict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import DataParallel
12
+
13
+ from .common import todevice
14
+
15
+
16
+ class Trainer (nn.Module):
17
+ """ Helper class to train a deep network.
18
+ Overload this class `forward_backward` for your actual needs.
19
+
20
+ Usage:
21
+ train = Trainer(net, loss, optimizer)
22
+ for epoch in range(n_epochs):
23
+ train(data_loader)
24
+ """
25
+ def __init__(self, net, loss, optimizer, epoch=0):
26
+ super().__init__()
27
+ self.net = net
28
+ self.loss = loss
29
+ self.optimizer = optimizer
30
+ self.epoch = epoch
31
+
32
+ @property
33
+ def device(self):
34
+ return next(self.net.parameters()).device
35
+
36
+ @property
37
+ def model(self):
38
+ return self.net.module if isinstance(self.net, DataParallel) else self.net
39
+
40
+ def distribute(self):
41
+ self.net = DataParallel(self.net) # DataDistributed not implemented yet
42
+
43
+ def __call__(self, data_loader):
44
+ print(f'>> Training (epoch {self.epoch} --> {self.epoch+1})')
45
+ self.net.train()
46
+
47
+ stats = defaultdict(list)
48
+
49
+ for batch in tqdm(data_loader):
50
+ batch = todevice(batch, self.device)
51
+
52
+ # compute gradient and do model update
53
+ self.optimizer.zero_grad()
54
+ details = self.forward_backward(batch)
55
+ self.optimizer.step()
56
+
57
+ for key, val in details.items():
58
+ stats[key].append( val )
59
+
60
+ self.epoch += 1
61
+
62
+ print(" Summary of losses during this epoch:")
63
+ for loss_name, vals in stats.items():
64
+ N = 1 + len(vals)//10
65
+ print(f" - {loss_name:10}: {avg(vals[:N]):.3f} --> {avg(vals[-N:]):.3f} (avg: {avg(vals):.3f})")
66
+
67
+ def forward_backward(self, inputs):
68
+ raise NotImplementedError()
69
+
70
+ def save(self, path):
71
+ print(f"\n>> Saving model to {path}")
72
+
73
+ data = {'model': self.model.state_dict(),
74
+ 'optimizer': self.optimizer.state_dict(),
75
+ 'loss': self.loss.state_dict(),
76
+ 'epoch': self.epoch}
77
+
78
+ torch.save(data, open(path,'wb'))
79
+
80
+ def load(self, path, resume=True):
81
+ print(f">> Loading weights from {path} ...")
82
+ checkpoint = torch.load(path, map_location='cpu')
83
+ assert isinstance(checkpoint, dict)
84
+
85
+ self.net.load_state_dict(checkpoint['model'])
86
+ if resume:
87
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
88
+ self.loss.load_state_dict(checkpoint['loss'])
89
+ self.epoch = checkpoint['epoch']
90
+ print(f" Resuming training at Epoch {self.epoch}!")
91
+
92
+
93
+ def get_loss( loss ):
94
+ """ returns a tuple (loss, dictionary of loss details)
95
+ """
96
+ assert isinstance(loss, dict)
97
+ grads = None
98
+
99
+ k,l = next(iter(loss.items())) # first item is assumed to be the main loss
100
+ if isinstance(l, tuple):
101
+ l, grads = l
102
+ loss[k] = l
103
+
104
+ return (l, grads), {k:float(v) for k,v in loss.items()}
105
+
106
+
107
+ def backward( loss ):
108
+ if isinstance(loss, tuple):
109
+ loss, grads = loss
110
+ else:
111
+ loss, grads = (loss, None)
112
+
113
+ assert loss == loss, 'loss is NaN'
114
+
115
+ if grads is None:
116
+ loss.backward()
117
+ else:
118
+ # dictionary of separate subgraphs
119
+ for var,grad in grads:
120
+ var.backward(grad)
121
+ return float(loss)
122
+
123
+
124
+ def avg( lis ):
125
+ return sum(lis) / len(lis)
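As the docstring suggests, forward_backward() is meant to be overloaded; a minimal sketch of a subclass, where the network, loss and batch keys are placeholders, relying on the get_loss() and backward() helpers defined above:

class MyTrainer (Trainer):
    def forward_backward(self, batch):
        pred = self.net(batch['input'])                 # hypothetical batch layout
        loss_dict = self.loss(pred, batch['target'])    # assumed to return a dict of named losses
        loss, details = get_loss(loss_dict)             # loss may be (value, grads)
        backward(loss)                                  # handles both plain losses and (value, grads)
        return details                                  # dict of floats, accumulated by Trainer.__call__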
tools/viz.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ import sys
6
+ from pdb import set_trace as bb
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+ import matplotlib.pyplot as pl; pl.ion()
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from core import functional as myF
15
+ from .common import cpu, nparray, image, image_with_trf
16
+
17
+
18
+ def dbgfig(*args, **kwargs):
19
+ assert len(args) >= 2
20
+ dbg = args[-1]
21
+ if isinstance(dbg, str):
22
+ dbg = dbg.split()
23
+ for name in args[:-1]:
24
+ if {name,'all'} & set(dbg):
25
+ return pl.figure(name, **kwargs)
26
+ return False
27
+
28
+
29
+ def noticks(ax=None):
30
+ if ax is None: ax = pl.gca()
31
+ ax.set_xticks(())
32
+ ax.set_yticks(())
33
+ return ax
34
+
35
+
36
+ def plot_grid( corres, ax1, ax2=None, marker='+' ):
37
+ """ corres = Nx2 or Nx4 list of correspondences
38
+ """
39
+ if marker is True: marker = '+'
40
+
41
+ corres = nparray(corres)
42
+ # make beautiful colors
43
+ center = corres[:,[1,0]].mean(axis=0)
44
+ colors = np.arctan2(*(corres[:,[1,0]] - center).T)
45
+ colors = np.int32(64*colors/np.pi) % 128
46
+
47
+ all_colors = np.unique(colors)
48
+ palette = {m:pl.cm.hsv(i/float(len(all_colors))) for i,m in enumerate(all_colors)}
49
+
50
+ for m in all_colors:
51
+ x, y = corres[colors==m,0:2].T
52
+ ax1.plot(x, y, marker, ms=10, mew=2, color=palette[m], scalex=0, scaley=0)
53
+
54
+ if not ax2: return
55
+ for m in all_colors:
56
+ x, y = corres[colors==m,2:4].T
57
+ ax2.plot(x, y, marker, ms=10, mew=2, color=palette[m], scalex=0, scaley=0)
58
+
59
+
60
+ def show_correspondences( img0, img1, corres, F=None, fig='last', show_grid=True, bb=None, clf=False):
61
+ img0, trf0 = img0 if isinstance(img0, tuple) else (img0, torch.eye(3))
62
+ img1, trf1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3))
63
+ if not bb: pl.ioff()
64
+ fig, axes = pl.subplots(2, 2, num=fig_num(fig, 'viz_corres'))
65
+ for i, ax in enumerate(axes.ravel()):
66
+ if clf: ax.cla()
67
+ noticks(ax).numaxis = i % 2
68
+ ax.imshow( [image(img0),image(img1)][i%2] )
69
+
70
+ if corres.shape == (3,3): # corres is a homography matrix
71
+ from pytools.hfuncs import applyh
72
+ H, W = axes[0,0].images[0].get_size()
73
+ pos1 = np.mgrid[:H,:W].reshape(2,-1)[::-1].T
74
+ pos2 = applyh(corres, pos1)
75
+ corres = np.concatenate((pos1,pos2), axis=-1)
76
+
77
+ inv = np.linalg.inv
78
+ corres = myF.affmul((inv(nparray(trf0)),inv(nparray(trf1))), nparray(corres)) # images are already downscaled
79
+ print(f">> Displaying {len(corres)} correspondences (move you mouse over the images)")
80
+
81
+ (ax1, ax2), (ax3, ax4) = axes
82
+ if corres.shape[-1] > 4:
83
+ corres = corres[corres[:,4]>0,:] # select non-null correspondences
84
+ if show_grid: plot_grid(corres, ax3, ax4, marker=show_grid)
85
+
86
+ def mouse_move(event):
87
+ if event.inaxes==None: return
88
+ numaxis = event.inaxes.numaxis
89
+ if numaxis<0: return
90
+ x,y = event.xdata, event.ydata
91
+ ax1.lines.clear()
92
+ ax2.lines.clear()
93
+ sl = slice(2*numaxis, 2*(numaxis+1))
94
+ n = np.sum((corres[:,sl] - [x,y])**2,axis=1).argmin() # find nearest point
95
+ print("\rdisplaying #%d (%d,%d) --> (%d,%d), score=%g, code=%g" % (n,
96
+ corres[n,0],corres[n,1],corres[n,2],corres[n,3],
97
+ corres[n,4] if corres.shape[-1] > 4 else np.nan,
98
+ corres[n,5] if corres.shape[-1] > 5 else np.nan), end=' '*7);sys.stdout.flush()
99
+ x,y = corres[n,0:2]
100
+ ax1.plot(x, y, '+', ms=10, mew=2, color='blue', scalex=False, scaley=False)
101
+ x,y = corres[n,2:4]
102
+ ax2.plot(x, y, '+', ms=10, mew=2, color='red', scalex=False, scaley=False)
103
+ if F is not None:
104
+ ax = None
105
+ if numaxis == 0:
106
+ line = corres[n,0:2] @ F[:2] + F[2]
107
+ ax = ax2
108
+ if numaxis == 1:
109
+ line = corres[n,2:4] @ F.T[:2] + F.T[2]
110
+ ax = ax1
111
+ if ax:
112
+ x = np.linspace(-10000,10000,2)
113
+ y = (line[2]+line[0]*x) / -line[1]
114
+ ax.plot(x, y, '-', scalex=0, scaley=0)
115
+
116
+ # we redraw only the concerned axes
117
+ renderer = fig.canvas.get_renderer()
118
+ ax1.draw(renderer)
119
+ ax2.draw(renderer)
120
+ fig.canvas.blit(ax1.bbox)
121
+ fig.canvas.blit(ax2.bbox)
122
+
123
+ cid_move = fig.canvas.mpl_connect('motion_notify_event',mouse_move)
124
+ pl.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=0.02, hspace=0.02)
125
+ bb() if bb else pl.show()
126
+ fig.canvas.mpl_disconnect(cid_move)
127
+
128
+
129
+ def closest( grid, event ):
130
+ query = (event.xdata, event.ydata)
131
+ n = np.linalg.norm(grid.reshape(-1,2) - query, axis=1).argmin()
132
+ return np.unravel_index(n, grid.shape[:2])
133
+
134
+
135
+ def local_maxima( arr2d, top=5 ):
136
+ maxpooled = F.max_pool2d( arr2d[None, None], 3, padding=1, stride=1)[0,0]
137
+ local_maxima = (arr2d == maxpooled).nonzero()
138
+ order = arr2d[local_maxima.split(1,dim=1)].ravel().argsort()
139
+ return local_maxima[order[-top:]].T
140
+
141
+
142
+ def fig_num( fig, default, clf=False ):
143
+ if fig == 'last': num = pl.gcf().number
144
+ elif fig: num = fig.number
145
+ else: num = default
146
+ if clf: pl.figure(num).clf()
147
+ return num
148
+
149
+
150
+ def viz_correlation_maps( img1, img2, corr, level=0, fig=None, grid1=None, grid2=None, show_grid=False, bb=bb, **kw ):
151
+ fig, ((ax1, ax2), (ax4, ax3)) = pl.subplots(2, 2, num=fig_num(fig, 'viz_correlation_maps', clf=True))
152
+ img1 = image(img1)
153
+ img2 = image(img2)
154
+ noticks(ax1).imshow( img1 )
155
+ noticks(ax2).imshow( img2 )
156
+ ax4.hist(corr.ravel()[7:7777777:7].cpu().numpy(), bins=50)
157
+
158
+ if isinstance(corr, tuple):
159
+ H1, W1 = corr.grid.shape[:2]
160
+ corr = torch.from_numpy(corr.res_map).view(H1,W1,*corr.res_map.shape[-2:])
161
+
162
+ if grid1 is None:
163
+ s1 = int(0.5 + np.sqrt(img1.size / (3 * corr[...,0,0].numel()))) # scale factor between img1 and corr
164
+ grid1 = nparray(torch.ones_like(corr[:,:,0,0]).nonzero()*s1)[:,1::-1]
165
+ if level == 0: grid1 += s1//2
166
+ if show_grid: plot_grid(grid1, ax1)
167
+ grid1 = nparray(grid1).reshape(*corr[:,:,0,0].shape,2)
168
+
169
+ if grid2 is None:
170
+ s2 = int(0.5 + np.sqrt(img2.size / (3 * corr[0,0,...].numel()))) # scale factor between img2 and corr
171
+ grid2 = nparray(torch.ones_like(corr[0,0]).nonzero()*s2)[:,::-1]
172
+ grid2 = nparray(grid2).reshape(*corr.shape[2:],2)
173
+
174
+ def mouse_move(ev):
175
+ if ev.inaxes is ax1:
176
+ ax3.images.clear()
177
+ n = closest(grid1, ev)
178
+ ax3.imshow(corr[n].cpu().float(), vmin=0, **kw)
179
+
180
+ # find local maxima
181
+ lm = nparray(local_maxima(corr[n]))
182
+ for ax in (ax3, ax2):
183
+ if ax is ax2 and not show_grid:
184
+ ax1.lines.clear()
185
+ ax1.plot(*grid1[n], 'xr', ms=10, scalex=0, scaley=0)
186
+ ax.lines.clear()
187
+ x, y = grid2[y,x].T if ax is ax2 else lm[::-1]
188
+ if ax is not ax3:
189
+ ax.plot(x, y, 'xr', ms=10, scalex=0, scaley=0, label='local maxima')
190
+ print(f"\rCorr channel {n}. Min={corr[n].min():g}, Avg={corr[n].mean():g}, Max={corr[n].max():g} ", end='')
191
+
192
+ mouse_move(FakeEvent(0,0,inaxes=ax1))
193
+ cid_move = fig.canvas.mpl_connect('motion_notify_event', mouse_move)
194
+ pl.subplots_adjust(0,0,1,1,0,0)
195
+ pl.sca(ax4)
196
+ if bb: bb(); fig.canvas.mpl_disconnect(cid_move)
197
+
198
+ def viz_correspondences( img1, img2, corres1, corres2, fig=None ):
199
+ img1, img2 = map(image, (img1, img2))
200
+ fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = pl.subplots(3,2, num=fig_num(fig, 'viz_correspondences'))
201
+ for ax in fig.axes: noticks(ax)
202
+ ax1.imshow( img1 )
203
+ ax2.imshow( img2 )
204
+ ax3.imshow( img1 )
205
+ ax4.imshow( img2 )
206
+ corres1, corres2 = map(cpu, (corres1, corres2))
207
+ plot_grid( corres1[0], ax1, ax2 )
208
+ plot_grid( corres2[0], ax3, ax4 )
209
+
210
+ corres1, corres2 = corres1[1].float(), corres2[1].float()
211
+ ceiling = np.ceil(max(corres1.max(), corres2.max()).item())
212
+ ax5.imshow( corres1, vmin=0, vmax=ceiling )
213
+ ax6.imshow( corres2, vmin=0, vmax=ceiling )
214
+ bb()
215
+
216
+
217
+ class FakeEvent:
218
+ def __init__(self, xdata, ydata, **kw):
219
+ self.xdata = xdata
220
+ self.ydata = ydata
221
+ for name, val in kw.items():
222
+ setattr(self, name, val)
223
+
224
+
225
+ def show_random_pairs( db, pair_idxs=None, **kw ):
226
+ print('Showing random pairs from', db)
227
+
228
+ if pair_idxs is None:
229
+ pair_idxs = np.random.permutation(len(db))
230
+
231
+ for pair_idx in pair_idxs:
232
+ print(f'{pair_idx=}')
233
+ try:
234
+ img1_path, img2_path = map(db.imgs.get_image_path, db.pairs[pair_idx])
235
+ print(f'{img1_path=}\n{img2_path=}')
236
+ if hasattr(db, 'get_corres_path'):
237
+ print(f'corres_path = {db.get_corres_path(pair_idx)}')
238
+ except: pass
239
+ (img1, img2), gt = db[pair_idx]
240
+
241
+ if 'corres' in gt:
242
+ corres = gt['corres']
243
+ else:
244
+ # make corres from homography
245
+ from datasets.utils import corres_from_homography
246
+ corres = corres_from_homography(gt['homography'], *img1.size)
247
+
248
+ show_correspondences(img1, img2, corres, **kw)
249
+
250
+
251
+ if __name__=='__main__':
252
+ import argparse
253
+ import test_singlescale as pump
254
+
255
+ parser = argparse.ArgumentParser('Correspondence visualization')
256
+ parser.add_argument('--img1', required=True, help='path to first image')
257
+ parser.add_argument('--img2', required=True, help='path to second image')
258
+ parser.add_argument('--corres', required=True, help='path to correspondences')
259
+ args = parser.parse_args()
260
+
261
+ corres = np.load(args.corres)['corres']
262
+
263
+ args.resize = 0 # don't resize images
264
+ imgs = tuple(map(image, pump.Main.load_images(args)))
265
+
266
+ show_correspondences(*imgs, corres)
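
Beyond the command-line entry point above, the interactive viewer can also be called directly. A minimal sketch, where the image paths and the random correspondences are placeholders (assuming the repository root is on PYTHONPATH):

    import numpy as np
    from PIL import Image
    from tools.viz import show_correspondences

    img1 = Image.open('img1.jpg').convert('RGB')   # placeholder paths
    img2 = Image.open('img2.jpg').convert('RGB')

    # fake (x1, y1, x2, y2) correspondences, one per row, for illustration only
    corres = np.random.rand(100, 4) * [img1.width, img1.height, img2.width, img2.height]

    show_correspondences(img1, img2, corres, show_grid=True)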
train.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright 2022-present NAVER Corp.
2
+ # CC BY-NC-SA 4.0
3
+ # Available only for non-commercial use
4
+
5
+ from pdb import set_trace as bb
6
+ import os
+ import os.path as osp
7
+ import torch
8
+ import torch.optim as optim
9
+ import torchvision.transforms as tvf
10
+
11
+ from tools import common, trainer
12
+ from datasets import *
13
+ from core.conv_mixer import ConvMixer
14
+ from core.losses import *
15
+
16
+
17
+ def parse_args():
18
+ import argparse
19
+ parser = argparse.ArgumentParser("Script to train PUMP")
20
+
21
+ parser.add_argument("--pretrained", type=str, default="", help='pretrained model path')
22
+ parser.add_argument("--save-path", type=str, required=True, help='directory to save model')
23
+
24
+ parser.add_argument("--epochs", type=int, default=50, help='number of training epochs')
25
+ parser.add_argument("--batch-size", "--bs", type=int, default=16, help="batch size")
26
+ parser.add_argument("--learning-rate", "--lr", type=str, default=1e-4)
27
+ parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)
28
+
29
+ parser.add_argument("--threads", type=int, default=8, help='number of worker threads')
30
+ parser.add_argument("--device", default='cuda')
31
+
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def main( args ):
37
+ device = args.device
38
+ common.mkdir_for(args.save_path)
39
+
40
+ # Create data loader
41
+ db = BalancedCatImagePairs(
42
+ 3125, SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)'),
43
+ 4875, SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)'),
44
+ 8000, SfM120k_Pairs())
45
+
46
+ db = FastPairLoader(db,
47
+ crop=256, transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise(25)',
48
+ p_swap=0.5, p_flip=0.5, scale_jitter=0.5)
49
+
50
+ print("Training image database =", db)
51
+ data_loader = torch.utils.data.DataLoader(db, batch_size=args.batch_size, shuffle=True,
52
+ num_workers=args.threads, collate_fn=collate_ordered, pin_memory=False, drop_last=True,
53
+ worker_init_fn=WorkerWithRngInit())
54
+
55
+ # create network
56
+ net = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9)
57
+ print(f"\n>> Creating {type(net).__name__} net ( Model size: {common.model_size(net)/1e6:.1f}M parameters )")
58
+
59
+ # create losses
60
+ loss = MultiLoss(alpha=0.3,
61
+ loss_sup = PixelAPLoss(nq=20, inner_bw=True, sampler=NghSampler(ngh=7)),
62
+ loss_unsup = DeepMatchingLoss(eps=0.03))
63
+
64
+ # create optimizer
65
+ optimizer = optim.Adam( [p for p in net.parameters() if p.requires_grad],
66
+ lr=args.learning_rate, weight_decay=args.weight_decay)
67
+
68
+ train = MyTrainer(net, loss, optimizer).to(device)
69
+
70
+ # initialization
71
+ final_model_path = osp.join(args.save_path,'model.pt')
72
+ last_model_path = osp.join(args.save_path,'model.pt.last')
73
+ if osp.exists( final_model_path ):
74
+ print('Already trained, nothing to do!')
75
+ return
76
+ elif args.pretrained:
77
+ train.load( args.pretrained )
78
+ elif osp.exists( last_model_path ):
79
+ train.load( last_model_path )
80
+
81
+ train = train.to(args.device)
82
+ if ',' in os.environ.get('CUDA_VISIBLE_DEVICES',''):
83
+ train.distribute()
84
+
85
+ # Training loop #
86
+ while train.epoch < args.epochs:
87
+ # shuffle dataset (select new pairs)
88
+ data_loader.dataset.set_epoch(train.epoch)
89
+
90
+ train(data_loader)
91
+
92
+ train.save(last_model_path)
93
+
94
+ # save final model
95
+ torch.save(train.model.state_dict(), final_model_path)
96
+
97
+
98
+ totensor = tvf.Compose([
99
+ common.ToTensor(),
100
+ tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
101
+ ])
102
+
103
+ class MyTrainer (trainer.Trainer):
104
+ """ This class implements the network training.
105
+ The forward_backward() method below is the one to overload to define how the forward pass and backpropagation are performed.
106
+ """
107
+ def forward_backward(self, inputs):
108
+ assert torch.is_grad_enabled() and self.net.training
109
+
110
+ (img1, img2), labels = inputs
111
+ output1 = self.net(totensor(img1))
112
+ output2 = self.net(totensor(img2))
113
+
114
+ loss, details = trainer.get_loss(self.loss(output1, output2, img1=img1, img2=img2, **labels))
115
+ trainer.backward(loss)
116
+ return details
117
+
118
+
119
+
120
+ if __name__ == '__main__':
121
+ main(parse_args())
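
Training is launched from the command line, e.g. python train.py --save-path output/pump --epochs 50, with the remaining hyper-parameters taking the defaults defined in parse_args(). Once main() has written the final checkpoint, it can be reloaded for inference roughly as follows; a sketch only, where 'output/pump/model.pt' stands for whatever --save-path was used at training time (assuming the repository root is on PYTHONPATH):

    import torch
    from core.conv_mixer import ConvMixer

    # same architecture hyper-parameters as in main()
    net = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9)
    net.load_state_dict(torch.load('output/pump/model.pt', map_location='cpu'))
    net.eval()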