2nzi committed on
Commit 3d1f2c9 · 1 Parent(s): efb3a7a

Upload 63 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .dockerignore +68 -0
  2. .gitignore +109 -0
  3. Dockerfile +10 -0
  4. LICENSE +339 -0
  5. PnLCalib.yml +235 -0
  6. README.md +134 -10
  7. api.py +162 -0
  8. config.yaml +33 -0
  9. config/field_config.py +164 -0
  10. config/hrnetv2_w48.yaml +35 -0
  11. config/hrnetv2_w48_l.yaml +35 -0
  12. data/__init__.py +1 -0
  13. data/line_data.py +115 -0
  14. get_camera_params.py +201 -0
  15. inference.py +286 -0
  16. model/cls_hrnet.py +479 -0
  17. model/cls_hrnet_l.py +478 -0
  18. model/dataloader.py +219 -0
  19. model/dataloader_l.py +202 -0
  20. model/losses.py +221 -0
  21. model/metrics.py +312 -0
  22. model/transforms.py +364 -0
  23. model/transformsWC.py +212 -0
  24. model/transformsWC_l.py +208 -0
  25. model/transforms_l.py +360 -0
  26. requirements.txt +44 -0
  27. run_api.py +13 -0
  28. scripts/eval_tswc.py +175 -0
  29. scripts/eval_wc14.py +154 -0
  30. scripts/inference_sn.py +186 -0
  31. scripts/inference_tswc.py +145 -0
  32. scripts/inference_wc14.py +148 -0
  33. scripts/run_pipeline_sn22.sh +23 -0
  34. scripts/run_pipeline_sn23.sh +23 -0
  35. scripts/run_pipeline_tswc.sh +20 -0
  36. scripts/run_pipeline_wc14.sh +20 -0
  37. scripts/run_pipeline_wc14_3D.sh +23 -0
  38. sn_calibration/ChallengeRules.md +44 -0
  39. sn_calibration/README.md +440 -0
  40. sn_calibration/requirements.txt +17 -0
  41. sn_calibration/resources/mean.npy +3 -0
  42. sn_calibration/resources/std.npy +3 -0
  43. sn_calibration/src/baseline_cameras.py +251 -0
  44. sn_calibration/src/camera.py +402 -0
  45. sn_calibration/src/dataloader.py +122 -0
  46. sn_calibration/src/detect_extremities.py +314 -0
  47. sn_calibration/src/evalai_camera.py +133 -0
  48. sn_calibration/src/evaluate_camera.py +365 -0
  49. sn_calibration/src/evaluate_extremities.py +274 -0
  50. sn_calibration/src/soccerpitch.py +528 -0
.dockerignore ADDED
@@ -0,0 +1,68 @@
1
+ # Exclude virtual environments
2
+ .venv/
3
+ env/
4
+ venv/
5
+ ENV/
6
+ env.bak/
7
+ venv.bak/
8
+
9
+ # Python cache
10
+ __pycache__/
11
+ *.pyc
12
+ *.pyo
13
+ *.pyd
14
+ *.pyc
15
+ *.pyw
16
+ *.pyz
17
+
18
+ # OS / IDE files
19
+ .DS_Store
20
+ .vscode/
21
+ .idea/
22
+ *.swp
23
+ *.swo
24
+ *~
25
+
26
+ # Git
27
+ .git/
28
+ .gitignore
29
+
30
+ # Logs & temp
31
+ *.log
32
+ *.tmp
33
+ *.temp
34
+ .cache/
35
+ tmp/
36
+
37
+ # Data and intermediate results
38
+ input/
39
+ output/
40
+ *.zip
41
+ *.png
42
+ *.jpg
43
+ *.jpeg
44
+ *.npy
45
+ *.pkl
46
+ *.h5
47
+ *.pth
48
+ *.pt
49
+
50
+ # Project-specific files
51
+ data/private/
52
+ UnderstandScripts/
53
+ env_test/
54
+ requirements_copy.txt
55
+ requirements_clean.txt
56
+
57
+ # Vercel & Docker files
58
+ .vercel/
59
+ Dockerfile
60
+ docker-compose.yml
61
+ .dockerignore
62
+
63
+ # Videos
64
+ *.mp4
65
+
66
+
67
+
68
+ models/
.gitignore ADDED
@@ -0,0 +1,109 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ share/python-wheels/
20
+ *.egg-info/
21
+ .installed.cfg
22
+ *.egg
23
+ MANIFEST
24
+
25
+ # Virtual environments
26
+ .env
27
+ .venv
28
+ env/
29
+ venv/
30
+ ENV/
31
+ env.bak/
32
+ venv.bak/
33
+ .venv/
34
+
35
+ # IDE
36
+ .vscode/
37
+ .idea/
38
+ *.swp
39
+ *.swo
40
+ *~
41
+
42
+ # OS
43
+ .DS_Store
44
+ .DS_Store?
45
+ ._*
46
+ .Spotlight-V100
47
+ .Trashes
48
+ ehthumbs.db
49
+ Thumbs.db
50
+
51
+ # Logs
52
+ *.log
53
+ logs/
54
+
55
+ # Temporary files
56
+ *.tmp
57
+ *.temp
58
+ .cache/
59
+ tmp/
60
+
61
+ # Data files (if sensitive)
62
+ data/private/
63
+ *.pkl
64
+ *.h5
65
+
66
+ # Model files (if too large)
67
+ *.pth
68
+ *.pt
69
+
70
+
71
+ # Environment variables
72
+ .env.local
73
+ .env.production
74
+
75
+ # Jupyter
76
+ .ipynb_checkpoints/
77
+
78
+ # pytest
79
+ .pytest_cache/
80
+
81
+ # Coverage
82
+ htmlcov/
83
+ .coverage
84
+ .coverage.*
85
+
86
+ # Docker (optional, depending on your needs)
87
+ # Dockerfile
88
+ # .dockerignore
89
+
90
+ # Vercel
91
+ .vercel
92
+
93
+ input/
94
+ output/
95
+ *.zip
96
+ *.png
97
+ *.jpg
98
+ *.npy
99
+
100
+ UnderstandScripts/
101
+ inference/
102
+ models/
103
+
104
+ env_test/
105
+
106
+ requirements_copy.txt
107
+
108
+ requirements_clean.txt
109
+
Dockerfile ADDED
@@ -0,0 +1,10 @@
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
LICENSE ADDED
@@ -0,0 +1,339 @@
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 2, June 1991
3
+
4
+ Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6
+ Everyone is permitted to copy and distribute verbatim copies
7
+ of this license document, but changing it is not allowed.
8
+
9
+ Preamble
10
+
11
+ The licenses for most software are designed to take away your
12
+ freedom to share and change it. By contrast, the GNU General Public
13
+ License is intended to guarantee your freedom to share and change free
14
+ software--to make sure the software is free for all its users. This
15
+ General Public License applies to most of the Free Software
16
+ Foundation's software and to any other program whose authors commit to
17
+ using it. (Some other Free Software Foundation software is covered by
18
+ the GNU Lesser General Public License instead.) You can apply it to
19
+ your programs, too.
20
+
21
+ When we speak of free software, we are referring to freedom, not
22
+ price. Our General Public Licenses are designed to make sure that you
23
+ have the freedom to distribute copies of free software (and charge for
24
+ this service if you wish), that you receive source code or can get it
25
+ if you want it, that you can change the software or use pieces of it
26
+ in new free programs; and that you know you can do these things.
27
+
28
+ To protect your rights, we need to make restrictions that forbid
29
+ anyone to deny you these rights or to ask you to surrender the rights.
30
+ These restrictions translate to certain responsibilities for you if you
31
+ distribute copies of the software, or if you modify it.
32
+
33
+ For example, if you distribute copies of such a program, whether
34
+ gratis or for a fee, you must give the recipients all the rights that
35
+ you have. You must make sure that they, too, receive or can get the
36
+ source code. And you must show them these terms so they know their
37
+ rights.
38
+
39
+ We protect your rights with two steps: (1) copyright the software, and
40
+ (2) offer you this license which gives you legal permission to copy,
41
+ distribute and/or modify the software.
42
+
43
+ Also, for each author's protection and ours, we want to make certain
44
+ that everyone understands that there is no warranty for this free
45
+ software. If the software is modified by someone else and passed on, we
46
+ want its recipients to know that what they have is not the original, so
47
+ that any problems introduced by others will not reflect on the original
48
+ authors' reputations.
49
+
50
+ Finally, any free program is threatened constantly by software
51
+ patents. We wish to avoid the danger that redistributors of a free
52
+ program will individually obtain patent licenses, in effect making the
53
+ program proprietary. To prevent this, we have made it clear that any
54
+ patent must be licensed for everyone's free use or not licensed at all.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ GNU GENERAL PUBLIC LICENSE
60
+ TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
61
+
62
+ 0. This License applies to any program or other work which contains
63
+ a notice placed by the copyright holder saying it may be distributed
64
+ under the terms of this General Public License. The "Program", below,
65
+ refers to any such program or work, and a "work based on the Program"
66
+ means either the Program or any derivative work under copyright law:
67
+ that is to say, a work containing the Program or a portion of it,
68
+ either verbatim or with modifications and/or translated into another
69
+ language. (Hereinafter, translation is included without limitation in
70
+ the term "modification".) Each licensee is addressed as "you".
71
+
72
+ Activities other than copying, distribution and modification are not
73
+ covered by this License; they are outside its scope. The act of
74
+ running the Program is not restricted, and the output from the Program
75
+ is covered only if its contents constitute a work based on the
76
+ Program (independent of having been made by running the Program).
77
+ Whether that is true depends on what the Program does.
78
+
79
+ 1. You may copy and distribute verbatim copies of the Program's
80
+ source code as you receive it, in any medium, provided that you
81
+ conspicuously and appropriately publish on each copy an appropriate
82
+ copyright notice and disclaimer of warranty; keep intact all the
83
+ notices that refer to this License and to the absence of any warranty;
84
+ and give any other recipients of the Program a copy of this License
85
+ along with the Program.
86
+
87
+ You may charge a fee for the physical act of transferring a copy, and
88
+ you may at your option offer warranty protection in exchange for a fee.
89
+
90
+ 2. You may modify your copy or copies of the Program or any portion
91
+ of it, thus forming a work based on the Program, and copy and
92
+ distribute such modifications or work under the terms of Section 1
93
+ above, provided that you also meet all of these conditions:
94
+
95
+ a) You must cause the modified files to carry prominent notices
96
+ stating that you changed the files and the date of any change.
97
+
98
+ b) You must cause any work that you distribute or publish, that in
99
+ whole or in part contains or is derived from the Program or any
100
+ part thereof, to be licensed as a whole at no charge to all third
101
+ parties under the terms of this License.
102
+
103
+ c) If the modified program normally reads commands interactively
104
+ when run, you must cause it, when started running for such
105
+ interactive use in the most ordinary way, to print or display an
106
+ announcement including an appropriate copyright notice and a
107
+ notice that there is no warranty (or else, saying that you provide
108
+ a warranty) and that users may redistribute the program under
109
+ these conditions, and telling the user how to view a copy of this
110
+ License. (Exception: if the Program itself is interactive but
111
+ does not normally print such an announcement, your work based on
112
+ the Program is not required to print an announcement.)
113
+
114
+ These requirements apply to the modified work as a whole. If
115
+ identifiable sections of that work are not derived from the Program,
116
+ and can be reasonably considered independent and separate works in
117
+ themselves, then this License, and its terms, do not apply to those
118
+ sections when you distribute them as separate works. But when you
119
+ distribute the same sections as part of a whole which is a work based
120
+ on the Program, the distribution of the whole must be on the terms of
121
+ this License, whose permissions for other licensees extend to the
122
+ entire whole, and thus to each and every part regardless of who wrote it.
123
+
124
+ Thus, it is not the intent of this section to claim rights or contest
125
+ your rights to work written entirely by you; rather, the intent is to
126
+ exercise the right to control the distribution of derivative or
127
+ collective works based on the Program.
128
+
129
+ In addition, mere aggregation of another work not based on the Program
130
+ with the Program (or with a work based on the Program) on a volume of
131
+ a storage or distribution medium does not bring the other work under
132
+ the scope of this License.
133
+
134
+ 3. You may copy and distribute the Program (or a work based on it,
135
+ under Section 2) in object code or executable form under the terms of
136
+ Sections 1 and 2 above provided that you also do one of the following:
137
+
138
+ a) Accompany it with the complete corresponding machine-readable
139
+ source code, which must be distributed under the terms of Sections
140
+ 1 and 2 above on a medium customarily used for software interchange; or,
141
+
142
+ b) Accompany it with a written offer, valid for at least three
143
+ years, to give any third party, for a charge no more than your
144
+ cost of physically performing source distribution, a complete
145
+ machine-readable copy of the corresponding source code, to be
146
+ distributed under the terms of Sections 1 and 2 above on a medium
147
+ customarily used for software interchange; or,
148
+
149
+ c) Accompany it with the information you received as to the offer
150
+ to distribute corresponding source code. (This alternative is
151
+ allowed only for noncommercial distribution and only if you
152
+ received the program in object code or executable form with such
153
+ an offer, in accord with Subsection b above.)
154
+
155
+ The source code for a work means the preferred form of the work for
156
+ making modifications to it. For an executable work, complete source
157
+ code means all the source code for all modules it contains, plus any
158
+ associated interface definition files, plus the scripts used to
159
+ control compilation and installation of the executable. However, as a
160
+ special exception, the source code distributed need not include
161
+ anything that is normally distributed (in either source or binary
162
+ form) with the major components (compiler, kernel, and so on) of the
163
+ operating system on which the executable runs, unless that component
164
+ itself accompanies the executable.
165
+
166
+ If distribution of executable or object code is made by offering
167
+ access to copy from a designated place, then offering equivalent
168
+ access to copy the source code from the same place counts as
169
+ distribution of the source code, even though third parties are not
170
+ compelled to copy the source along with the object code.
171
+
172
+ 4. You may not copy, modify, sublicense, or distribute the Program
173
+ except as expressly provided under this License. Any attempt
174
+ otherwise to copy, modify, sublicense or distribute the Program is
175
+ void, and will automatically terminate your rights under this License.
176
+ However, parties who have received copies, or rights, from you under
177
+ this License will not have their licenses terminated so long as such
178
+ parties remain in full compliance.
179
+
180
+ 5. You are not required to accept this License, since you have not
181
+ signed it. However, nothing else grants you permission to modify or
182
+ distribute the Program or its derivative works. These actions are
183
+ prohibited by law if you do not accept this License. Therefore, by
184
+ modifying or distributing the Program (or any work based on the
185
+ Program), you indicate your acceptance of this License to do so, and
186
+ all its terms and conditions for copying, distributing or modifying
187
+ the Program or works based on it.
188
+
189
+ 6. Each time you redistribute the Program (or any work based on the
190
+ Program), the recipient automatically receives a license from the
191
+ original licensor to copy, distribute or modify the Program subject to
192
+ these terms and conditions. You may not impose any further
193
+ restrictions on the recipients' exercise of the rights granted herein.
194
+ You are not responsible for enforcing compliance by third parties to
195
+ this License.
196
+
197
+ 7. If, as a consequence of a court judgment or allegation of patent
198
+ infringement or for any other reason (not limited to patent issues),
199
+ conditions are imposed on you (whether by court order, agreement or
200
+ otherwise) that contradict the conditions of this License, they do not
201
+ excuse you from the conditions of this License. If you cannot
202
+ distribute so as to satisfy simultaneously your obligations under this
203
+ License and any other pertinent obligations, then as a consequence you
204
+ may not distribute the Program at all. For example, if a patent
205
+ license would not permit royalty-free redistribution of the Program by
206
+ all those who receive copies directly or indirectly through you, then
207
+ the only way you could satisfy both it and this License would be to
208
+ refrain entirely from distribution of the Program.
209
+
210
+ If any portion of this section is held invalid or unenforceable under
211
+ any particular circumstance, the balance of the section is intended to
212
+ apply and the section as a whole is intended to apply in other
213
+ circumstances.
214
+
215
+ It is not the purpose of this section to induce you to infringe any
216
+ patents or other property right claims or to contest validity of any
217
+ such claims; this section has the sole purpose of protecting the
218
+ integrity of the free software distribution system, which is
219
+ implemented by public license practices. Many people have made
220
+ generous contributions to the wide range of software distributed
221
+ through that system in reliance on consistent application of that
222
+ system; it is up to the author/donor to decide if he or she is willing
223
+ to distribute software through any other system and a licensee cannot
224
+ impose that choice.
225
+
226
+ This section is intended to make thoroughly clear what is believed to
227
+ be a consequence of the rest of this License.
228
+
229
+ 8. If the distribution and/or use of the Program is restricted in
230
+ certain countries either by patents or by copyrighted interfaces, the
231
+ original copyright holder who places the Program under this License
232
+ may add an explicit geographical distribution limitation excluding
233
+ those countries, so that distribution is permitted only in or among
234
+ countries not thus excluded. In such case, this License incorporates
235
+ the limitation as if written in the body of this License.
236
+
237
+ 9. The Free Software Foundation may publish revised and/or new versions
238
+ of the General Public License from time to time. Such new versions will
239
+ be similar in spirit to the present version, but may differ in detail to
240
+ address new problems or concerns.
241
+
242
+ Each version is given a distinguishing version number. If the Program
243
+ specifies a version number of this License which applies to it and "any
244
+ later version", you have the option of following the terms and conditions
245
+ either of that version or of any later version published by the Free
246
+ Software Foundation. If the Program does not specify a version number of
247
+ this License, you may choose any version ever published by the Free Software
248
+ Foundation.
249
+
250
+ 10. If you wish to incorporate parts of the Program into other free
251
+ programs whose distribution conditions are different, write to the author
252
+ to ask for permission. For software which is copyrighted by the Free
253
+ Software Foundation, write to the Free Software Foundation; we sometimes
254
+ make exceptions for this. Our decision will be guided by the two goals
255
+ of preserving the free status of all derivatives of our free software and
256
+ of promoting the sharing and reuse of software generally.
257
+
258
+ NO WARRANTY
259
+
260
+ 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
261
+ FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
262
+ OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
263
+ PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
264
+ OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
265
+ MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
266
+ TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
267
+ PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
268
+ REPAIR OR CORRECTION.
269
+
270
+ 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
271
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
272
+ REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
273
+ INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
274
+ OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
275
+ TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
276
+ YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
277
+ PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
278
+ POSSIBILITY OF SUCH DAMAGES.
279
+
280
+ END OF TERMS AND CONDITIONS
281
+
282
+ How to Apply These Terms to Your New Programs
283
+
284
+ If you develop a new program, and you want it to be of the greatest
285
+ possible use to the public, the best way to achieve this is to make it
286
+ free software which everyone can redistribute and change under these terms.
287
+
288
+ To do so, attach the following notices to the program. It is safest
289
+ to attach them to the start of each source file to most effectively
290
+ convey the exclusion of warranty; and each file should have at least
291
+ the "copyright" line and a pointer to where the full notice is found.
292
+
293
+ <one line to give the program's name and a brief idea of what it does.>
294
+ Copyright (C) <year> <name of author>
295
+
296
+ This program is free software; you can redistribute it and/or modify
297
+ it under the terms of the GNU General Public License as published by
298
+ the Free Software Foundation; either version 2 of the License, or
299
+ (at your option) any later version.
300
+
301
+ This program is distributed in the hope that it will be useful,
302
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
303
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
304
+ GNU General Public License for more details.
305
+
306
+ You should have received a copy of the GNU General Public License along
307
+ with this program; if not, write to the Free Software Foundation, Inc.,
308
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
309
+
310
+ Also add information on how to contact you by electronic and paper mail.
311
+
312
+ If the program is interactive, make it output a short notice like this
313
+ when it starts in an interactive mode:
314
+
315
+ Gnomovision version 69, Copyright (C) year name of author
316
+ Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317
+ This is free software, and you are welcome to redistribute it
318
+ under certain conditions; type `show c' for details.
319
+
320
+ The hypothetical commands `show w' and `show c' should show the appropriate
321
+ parts of the General Public License. Of course, the commands you use may
322
+ be called something other than `show w' and `show c'; they could even be
323
+ mouse-clicks or menu items--whatever suits your program.
324
+
325
+ You should also get your employer (if you work as a programmer) or your
326
+ school, if any, to sign a "copyright disclaimer" for the program, if
327
+ necessary. Here is a sample; alter the names:
328
+
329
+ Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330
+ `Gnomovision' (which makes passes at compilers) written by James Hacker.
331
+
332
+ <signature of Ty Coon>, 1 April 1989
333
+ Ty Coon, President of Vice
334
+
335
+ This General Public License does not permit incorporating your program into
336
+ proprietary programs. If your program is a subroutine library, you may
337
+ consider it more useful to permit linking proprietary applications with the
338
+ library. If this is what you want to do, use the GNU Lesser General
339
+ Public License instead of this License.
PnLCalib.yml ADDED
@@ -0,0 +1,235 @@
1
+ name: PnLCalib
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=conda_forge
9
+ - _openmp_mutex=4.5=2_kmp_llvm
10
+ - alsa-lib=1.2.8=h166bdaf_0
11
+ - attr=2.5.1=h166bdaf_1
12
+ - blas=2.116=mkl
13
+ - blas-devel=3.9.0=16_linux64_mkl
14
+ - brotli=1.1.0=hd590300_1
15
+ - brotli-bin=1.1.0=hd590300_1
16
+ - brotli-python=1.1.0=py39h3d6467e_1
17
+ - bzip2=1.0.8=hd590300_5
18
+ - ca-certificates=2024.6.2=hbcca054_0
19
+ - cairo=1.16.0=ha61ee94_1014
20
+ - certifi=2024.6.2=pyhd8ed1ab_0
21
+ - cffi=1.16.0=py39h7a31438_0
22
+ - charset-normalizer=3.3.2=pyhd8ed1ab_0
23
+ - colorama=0.4.6=pyhd8ed1ab_0
24
+ - contourpy=1.2.1=py39h7633fee_0
25
+ - cuda-cudart=12.1.105=0
26
+ - cuda-cupti=12.1.105=0
27
+ - cuda-libraries=12.1.0=0
28
+ - cuda-nvrtc=12.1.105=0
29
+ - cuda-nvtx=12.1.105=0
30
+ - cuda-opencl=12.5.39=0
31
+ - cuda-runtime=12.1.0=0
32
+ - cuda-version=12.5=3
33
+ - cycler=0.12.1=pyhd8ed1ab_0
34
+ - dbus=1.13.6=h5008d03_3
35
+ - expat=2.6.2=h59595ed_0
36
+ - ffmpeg=4.3=hf484d3e_0
37
+ - fftw=3.3.10=nompi_hf1063bd_110
38
+ - filelock=3.15.4=pyhd8ed1ab_0
39
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
40
+ - font-ttf-inconsolata=3.000=h77eed37_0
41
+ - font-ttf-source-code-pro=2.038=h77eed37_0
42
+ - font-ttf-ubuntu=0.83=h77eed37_2
43
+ - fontconfig=2.14.2=h14ed4e7_0
44
+ - fonts-conda-ecosystem=1=0
45
+ - fonts-conda-forge=1=0
46
+ - fonttools=4.53.0=py39hd3abc70_0
47
+ - freetype=2.12.1=h267a509_2
48
+ - geos=3.12.1=h59595ed_0
49
+ - gettext=0.22.5=h59595ed_2
50
+ - gettext-tools=0.22.5=h59595ed_2
51
+ - glib=2.80.2=hf974151_0
52
+ - glib-tools=2.80.2=hb6ce0ca_0
53
+ - gmp=6.3.0=hac33072_2
54
+ - gmpy2=2.1.5=py39h048c657_1
55
+ - gnutls=3.6.13=h85f3911_1
56
+ - graphite2=1.3.13=h59595ed_1003
57
+ - gst-plugins-base=1.22.0=h4243ec0_2
58
+ - gstreamer=1.22.0=h25f0c4b_2
59
+ - gstreamer-orc=0.4.38=hd590300_0
60
+ - h2=4.1.0=pyhd8ed1ab_0
61
+ - harfbuzz=6.0.0=h8e241bc_0
62
+ - hpack=4.0.0=pyh9f0ad1d_0
63
+ - hyperframe=6.0.1=pyhd8ed1ab_0
64
+ - icu=70.1=h27087fc_0
65
+ - idna=3.7=pyhd8ed1ab_0
66
+ - importlib-resources=6.4.0=pyhd8ed1ab_0
67
+ - importlib_resources=6.4.0=pyhd8ed1ab_0
68
+ - jack=1.9.22=h11f4161_0
69
+ - jinja2=3.1.4=pyhd8ed1ab_0
70
+ - jpeg=9e=h166bdaf_2
71
+ - keyutils=1.6.1=h166bdaf_0
72
+ - kiwisolver=1.4.5=py39h7633fee_1
73
+ - krb5=1.20.1=h81ceb04_0
74
+ - lame=3.100=h166bdaf_1003
75
+ - lcms2=2.15=hfd0df8a_0
76
+ - ld_impl_linux-64=2.40=hf3520f5_7
77
+ - lerc=4.0.0=h27087fc_0
78
+ - libasprintf=0.22.5=h661eb56_2
79
+ - libasprintf-devel=0.22.5=h661eb56_2
80
+ - libblas=3.9.0=16_linux64_mkl
81
+ - libbrotlicommon=1.1.0=hd590300_1
82
+ - libbrotlidec=1.1.0=hd590300_1
83
+ - libbrotlienc=1.1.0=hd590300_1
84
+ - libcap=2.67=he9d0100_0
85
+ - libcblas=3.9.0=16_linux64_mkl
86
+ - libclang=15.0.7=default_h127d8a8_5
87
+ - libclang13=15.0.7=default_h5d6823c_5
88
+ - libcublas=12.1.0.26=0
89
+ - libcufft=11.0.2.4=0
90
+ - libcufile=1.10.0.4=0
91
+ - libcups=2.3.3=h36d4200_3
92
+ - libcurand=10.3.6.39=0
93
+ - libcusolver=11.4.4.55=0
94
+ - libcusparse=12.0.2.55=0
95
+ - libdb=6.2.32=h9c3ff4c_0
96
+ - libdeflate=1.17=h0b41bf4_0
97
+ - libedit=3.1.20191231=he28a2e2_2
98
+ - libevent=2.1.10=h28343ad_4
99
+ - libexpat=2.6.2=h59595ed_0
100
+ - libffi=3.4.2=h7f98852_5
101
+ - libflac=1.4.3=h59595ed_0
102
+ - libgcc-ng=14.1.0=h77fa898_0
103
+ - libgcrypt=1.11.0=h4ab18f5_0
104
+ - libgettextpo=0.22.5=h59595ed_2
105
+ - libgettextpo-devel=0.22.5=h59595ed_2
106
+ - libgfortran-ng=14.1.0=h69a702a_0
107
+ - libgfortran5=14.1.0=hc5f4f2c_0
108
+ - libglib=2.80.2=hf974151_0
109
+ - libgomp=14.1.0=h77fa898_0
110
+ - libgpg-error=1.50=h4f305b6_0
111
+ - libhwloc=2.9.1=hd6dc26d_0
112
+ - libiconv=1.17=hd590300_2
113
+ - libjpeg-turbo=2.0.0=h9bf148f_0
114
+ - liblapack=3.9.0=16_linux64_mkl
115
+ - liblapacke=3.9.0=16_linux64_mkl
116
+ - libllvm15=15.0.7=hadd5161_1
117
+ - libnpp=12.0.2.50=0
118
+ - libnsl=2.0.1=hd590300_0
119
+ - libnvjitlink=12.1.105=0
120
+ - libnvjpeg=12.1.1.14=0
121
+ - libogg=1.3.5=h4ab18f5_0
122
+ - libopus=1.3.1=h7f98852_1
123
+ - libpng=1.6.43=h2797004_0
124
+ - libpq=15.3=hbcd7760_1
125
+ - libsndfile=1.2.2=hc60ed4a_1
126
+ - libsqlite=3.46.0=hde9e2c9_0
127
+ - libstdcxx-ng=14.1.0=hc0a3c3a_0
128
+ - libsystemd0=253=h8c4010b_1
129
+ - libtiff=4.5.0=h6adf6a1_2
130
+ - libtool=2.4.7=h27087fc_0
131
+ - libudev1=253=h0b41bf4_1
132
+ - libuuid=2.38.1=h0b41bf4_0
133
+ - libvorbis=1.3.7=h9c3ff4c_0
134
+ - libwebp-base=1.4.0=hd590300_0
135
+ - libxcb=1.13=h7f98852_1004
136
+ - libxcrypt=4.4.36=hd590300_1
137
+ - libxkbcommon=1.5.0=h79f4944_1
138
+ - libxml2=2.10.3=hca2bb57_4
139
+ - libzlib=1.2.13=h4ab18f5_6
140
+ - llvm-openmp=15.0.7=h0cdce71_0
141
+ - lz4-c=1.9.4=hcb278e6_0
142
+ - markupsafe=2.1.5=py39hd1e30aa_0
143
+ - matplotlib=3.8.4=py39hf3d152e_2
144
+ - matplotlib-base=3.8.4=py39h10d1fc8_2
145
+ - mkl=2022.1.0=h84fe81f_915
146
+ - mkl-devel=2022.1.0=ha770c72_916
147
+ - mkl-include=2022.1.0=h84fe81f_915
148
+ - mpc=1.3.1=hfe3b2da_0
149
+ - mpfr=4.2.1=h9458935_1
150
+ - mpg123=1.32.6=h59595ed_0
151
+ - mpmath=1.3.0=pyhd8ed1ab_0
152
+ - munkres=1.1.4=pyh9f0ad1d_0
153
+ - mysql-common=8.0.33=hf1915f5_6
154
+ - mysql-libs=8.0.33=hca2cd23_6
155
+ - ncurses=6.5=h59595ed_0
156
+ - nettle=3.6=he412f7d_0
157
+ - networkx=3.2.1=pyhd8ed1ab_0
158
+ - nspr=4.35=h27087fc_0
159
+ - nss=3.100=hca3bf56_0
160
+ - numpy=2.0.0=py39ha0965c0_0
161
+ - openh264=2.1.1=h780b84a_0
162
+ - openjpeg=2.5.0=hfec8fc6_2
163
+ - openssl=3.1.6=h4ab18f5_0
164
+ - packaging=24.1=pyhd8ed1ab_0
165
+ - pcre2=10.43=hcad00b1_0
166
+ - pillow=9.4.0=py39h2320bf1_1
167
+ - pip=24.0=pyhd8ed1ab_0
168
+ - pixman=0.43.2=h59595ed_0
169
+ - ply=3.11=pyhd8ed1ab_2
170
+ - pthread-stubs=0.4=h36c2ea0_1001
171
+ - pulseaudio=16.1=hcb278e6_3
172
+ - pulseaudio-client=16.1=h5195f5e_3
173
+ - pulseaudio-daemon=16.1=ha8d29e2_3
174
+ - pycparser=2.22=pyhd8ed1ab_0
175
+ - pyparsing=3.1.2=pyhd8ed1ab_0
176
+ - pyqt=5.15.9=py39h52134e7_5
177
+ - pyqt5-sip=12.12.2=py39h3d6467e_5
178
+ - pysocks=1.7.1=pyha2e5f31_6
179
+ - python=3.9.18=h0755675_0_cpython
180
+ - python-dateutil=2.9.0=pyhd8ed1ab_0
181
+ - python_abi=3.9=4_cp39
182
+ - pytorch=2.3.1=py3.9_cuda12.1_cudnn8.9.2_0
183
+ - pytorch-cuda=12.1=ha16c6d3_5
184
+ - pytorch-mutex=1.0=cuda
185
+ - pyyaml=6.0.1=py39hd1e30aa_1
186
+ - qt-main=5.15.8=h5d23da1_6
187
+ - readline=8.2=h8228510_1
188
+ - requests=2.32.3=pyhd8ed1ab_0
189
+ - setuptools=70.1.1=pyhd8ed1ab_0
190
+ - shapely=2.0.4=py39h5a575da_1
191
+ - sip=6.7.12=py39h3d6467e_0
192
+ - six=1.16.0=pyh6c4a22f_0
193
+ - sympy=1.12.1=pypyh2585a3b_103
194
+ - tbb=2021.9.0=hf52228f_0
195
+ - tk=8.6.13=noxft_h4845f30_101
196
+ - toml=0.10.2=pyhd8ed1ab_0
197
+ - tomli=2.0.1=pyhd8ed1ab_0
198
+ - torchaudio=2.3.1=py39_cu121
199
+ - torchtriton=2.3.1=py39
200
+ - torchvision=0.18.1=py39_cu121
201
+ - tornado=6.4.1=py39hd3abc70_0
202
+ - tqdm=4.66.4=pyhd8ed1ab_0
203
+ - typing_extensions=4.12.2=pyha770c72_0
204
+ - tzdata=2024a=h0c530f3_0
205
+ - unicodedata2=15.1.0=py39hd1e30aa_0
206
+ - urllib3=2.2.2=pyhd8ed1ab_1
207
+ - wheel=0.43.0=pyhd8ed1ab_1
208
+ - xcb-util=0.4.0=h516909a_0
209
+ - xcb-util-image=0.4.0=h166bdaf_0
210
+ - xcb-util-keysyms=0.4.0=h516909a_0
211
+ - xcb-util-renderutil=0.3.9=h166bdaf_0
212
+ - xcb-util-wm=0.4.1=h516909a_0
213
+ - xkeyboard-config=2.38=h0b41bf4_0
214
+ - xorg-kbproto=1.0.7=h7f98852_1002
215
+ - xorg-libice=1.1.1=hd590300_0
216
+ - xorg-libsm=1.2.4=h7391055_0
217
+ - xorg-libx11=1.8.4=h0b41bf4_0
218
+ - xorg-libxau=1.0.11=hd590300_0
219
+ - xorg-libxdmcp=1.1.3=h7f98852_0
220
+ - xorg-libxext=1.3.4=h0b41bf4_2
221
+ - xorg-libxrender=0.9.10=h7f98852_1003
222
+ - xorg-renderproto=0.11.1=h7f98852_1002
223
+ - xorg-xextproto=7.3.0=h0b41bf4_1003
224
+ - xorg-xproto=7.0.31=h7f98852_1007
225
+ - xz=5.2.6=h166bdaf_0
226
+ - yaml=0.2.5=h7f98852_2
227
+ - zipp=3.19.2=pyhd8ed1ab_0
228
+ - zlib=1.2.13=h4ab18f5_6
229
+ - zstandard=0.22.0=py39h81c9582_1
230
+ - zstd=1.5.6=ha6fb4c9_0
231
+ - pip:
232
+ - lsq-ellipse==2.2.1
233
+ - opencv-python==4.10.0.84
234
+ - scipy==1.13.1
235
+ prefix: /home/mgutierrez/anaconda3/envs/PnLCalib
README.md CHANGED
@@ -1,10 +1,134 @@
1
- ---
2
- title: PnLCalib
3
- emoji: 🚀
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # PnLCalib API
2
+
3
+ > REST API for camera calibration from football pitch lines
4
+
5
+ ## About this project
6
+
7
+ This API is based on the original research work of the **SoccerNet Camera Calibration Challenge**. It turns the existing calibration algorithms into an accessible, easy-to-use REST API.
8
+
9
+ ### Original work
10
+ 📍 **Source repositories**: [SoccerNet Camera Calibration](https://github.com/SoccerNet/sn-calibration) and [Marc Gutiérrez-Pérez's PnLCalib](https://github.com/mguti97/PnLCalib)
11
+ 📖 **Paper**: SoccerNet Camera Calibration Challenge
12
+ 👥 **Authors**: SoccerNet team
13
+
14
+ ## Features
15
+
16
+ ✅ Automatic camera calibration from images of a football pitch
17
+ ✅ REST API built with FastAPI
18
+ ✅ Supported image formats: JPG, PNG
19
+
20
+ ## Local installation
21
+
22
+ ```bash
23
+ # Clone the repository
24
+ git clone https://github.com/2nzi/PnLCalib.git
25
+ cd PnLCalib
26
+
27
+ # Install the dependencies
28
+ pip install -r requirements.txt
29
+
30
+ # Start the API
31
+ python run_api.py
32
+ ```
33
+
34
+ The API will be available at: http://localhost:8000
35
+
36
+ ## Usage
37
+
38
+ ### Main endpoint: `/calibrate`
39
+
40
+ **POST** `/calibrate` - Calibrate a camera from an image and the pitch lines
41
+
42
+ **Parameters:**
43
+ - `image`: image file (multipart/form-data)
44
+ - `lines_data`: JSON of the pitch lines (string)
45
+
46
+ ### Usage examples
47
+
48
+ #### With JavaScript/Fetch
49
+ ```javascript
50
+ const formData = new FormData();
51
+ formData.append('image', imageFile);
52
+ formData.append('lines_data', JSON.stringify({
53
+ "Big rect. right top": [
54
+ {"x": 1342.88, "y": 1076.99},
55
+ {"x": 1484.74, "y": 906.37}
56
+ ],
57
+ "Big rect. right main": [
58
+ {"x": 1484.74, "y": 906.37},
59
+ {"x": 1049.62, "y": 748.02}
60
+ ],
61
+ "Circle central": [
62
+ {"x": 1580.73, "y": 269.84},
63
+ {"x": 1533.83, "y": 288.86}
64
+ ]
65
+ // ... other lines
66
+ }));
67
+
68
+ const response = await fetch('http://localhost:8000/calibrate', {
69
+ method: 'POST',
70
+ body: formData
71
+ });
72
+
73
+ const result = await response.json();
74
+ console.log('Calibration parameters:', result.camera_parameters);
75
+ ```
76
+
77
+ #### With curl
78
+ ```bash
79
+ curl -X POST "http://localhost:8000/calibrate" \
80
+   -F "image=@your_image.jpg" \
81
+   -F 'lines_data={"Big rect. right top":[{"x":1342.88,"y":1076.99}]}'
82
+ ```
83
+
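+ #### With Python (requests)
+
+ For reference, a minimal sketch using the `requests` package (the file name and the single line entry below are placeholders; use your own annotated lines):
+
+ ```python
+ import json
+ import requests
+
+ # Placeholder annotation: each entry maps a pitch line name to its image points.
+ lines = {
+     "Big rect. right top": [
+         {"x": 1342.88, "y": 1076.99},
+         {"x": 1484.74, "y": 906.37}
+     ]
+     # ... other lines
+ }
+
+ with open("your_image.jpg", "rb") as f:
+     response = requests.post(
+         "http://localhost:8000/calibrate",
+         files={"image": ("your_image.jpg", f, "image/jpeg")},
+         data={"lines_data": json.dumps(lines)},
+     )
+
+ response.raise_for_status()
+ print(response.json()["camera_parameters"])
+ ```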
84
+ ### Response format
85
+
86
+ ```json
87
+ {
88
+ "status": "success",
89
+ "camera_parameters": {
90
+ "pan_degrees": -45.2,
91
+ "tilt_degrees": 12.8,
92
+ "roll_degrees": 1.2,
93
+ "position_meters": [10.5, 20.3, 5.8],
94
+ "x_focal_length": 1200.5,
95
+ "y_focal_length": 1201.2,
96
+ "principal_point": [960, 540]
97
+ },
98
+ "input_lines": { /* validated lines */ },
99
+ "message": "Calibration réussie"
100
+ }
101
+ ```
102
+
103
+ ## Documentation
104
+
105
+ Once the API is running, the interactive documentation is available at:
106
+ - **Swagger UI** : http://localhost:8000/docs
107
+ - **ReDoc** : http://localhost:8000/redoc
108
+
109
+ ## Health Check
110
+
111
+ ```bash
112
+ curl http://localhost:8000/health
113
+ ```
114
+
115
+ ## Supported pitch lines
116
+
117
+ The API accepts the following pitch line types:
118
+ - `Big rect. right/left top/main/bottom`
119
+ - `Small rect. right/left top/main`
120
+ - `Circle right/left/central`
121
+ - `Side line bottom/left`
122
+ - `Middle line`
123
+ ...
124
+
125
+ Each line is defined as a list of points with `x` and `y` coordinates.
126
+
127
+ ## Credits
128
+
129
+ Based on the original work of the SoccerNet team for the Camera Calibration Challenge and on [Marc Gutiérrez-Pérez](https://github.com/mguti97/PnLCalib).
130
+
131
+ Turned into a REST API by [2nzi](https://github.com/2nzi).
132
+
133
+ ## License
134
+ See [LICENSE](LICENSE) - based on the license of the original SoccerNet project.
api.py ADDED
@@ -0,0 +1,162 @@
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import Dict, List, Any
5
+ import json
6
+ import tempfile
7
+ import os
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ from get_camera_params import get_camera_parameters
12
+
13
+ app = FastAPI(
14
+ title="Football Vision Calibration API",
15
+ description="API pour la calibration de caméras à partir de lignes de terrain de football",
16
+ version="1.0.0"
17
+ )
18
+
19
+ # CORS configuration to allow requests from the frontend
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["*"], # In production, restrict this to the allowed domains
23
+ allow_credentials=True,
24
+ allow_methods=["*"],
25
+ allow_headers=["*"],
26
+ )
27
+
28
+ # Pydantic models for data validation
29
+ class Point(BaseModel):
30
+ x: float
31
+ y: float
32
+
33
+ class LinePolygon(BaseModel):
34
+ points: List[Point]
35
+
36
+ class CalibrationRequest(BaseModel):
37
+ lines: Dict[str, List[Point]]
38
+
39
+ class CalibrationResponse(BaseModel):
40
+ status: str
41
+ camera_parameters: Dict[str, Any]
42
+ input_lines: Dict[str, List[Point]]
43
+ message: str
44
+
45
+ @app.get("/")
46
+ async def root():
47
+ return {
48
+ "message": "Football Vision Calibration API",
49
+ "version": "1.0.0",
50
+ "endpoints": {
51
+ "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
52
+ "/health": "GET - Vérifier l'état de l'API"
53
+ }
54
+ }
55
+
56
+ @app.get("/health")
57
+ async def health_check():
58
+ return {"status": "healthy", "message": "API is running"}
59
+
60
+ @app.post("/calibrate", response_model=CalibrationResponse)
61
+ async def calibrate_camera(
62
+ image: UploadFile = File(..., description="Image du terrain de football"),
63
+ lines_data: str = Form(..., description="JSON des lignes du terrain")
64
+ ):
65
+ """
66
+ Calibrate a camera from an image and the pitch lines.
67
+
68
+ Args:
69
+ image: Image of the football pitch (formats: jpg, jpeg, png)
70
+ lines_data: JSON containing the pitch lines in the format:
71
+ {"line_name": [{"x": float, "y": float}, ...], ...}
72
+
73
+ Returns:
74
+ Camera calibration parameters and the input lines
75
+ """
76
+ try:
77
+ # Validate the image format
78
+ if not image.content_type.startswith('image/'):
79
+ raise HTTPException(status_code=400, detail="Le fichier doit être une image")
80
+
81
+ # Parse the lines data
82
+ try:
83
+ lines_dict = json.loads(lines_data)
84
+ except json.JSONDecodeError:
85
+ raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")
86
+
87
+ # Validate the structure of the lines
88
+ validated_lines = {}
89
+ for line_name, points in lines_dict.items():
90
+ if not isinstance(points, list):
91
+ raise HTTPException(
92
+ status_code=400,
93
+ detail=f"Les points de la ligne '{line_name}' doivent être une liste"
94
+ )
95
+
96
+ validated_points = []
97
+ for i, point in enumerate(points):
98
+ if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
99
+ raise HTTPException(
100
+ status_code=400,
101
+ detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
102
+ )
103
+ try:
104
+ validated_points.append({
105
+ "x": float(point['x']),
106
+ "y": float(point['y'])
107
+ })
108
+ except (ValueError, TypeError):
109
+ raise HTTPException(
110
+ status_code=400,
111
+ detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
112
+ )
113
+
114
+ validated_lines[line_name] = validated_points
115
+
116
+ # Save the image to a temporary file
117
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{image.filename.split('.')[-1]}") as temp_file:
118
+ content = await image.read()
119
+ temp_file.write(content)
120
+ temp_image_path = temp_file.name
121
+
122
+ try:
123
+ # Validate the image
124
+ pil_image = Image.open(temp_image_path)
125
+ pil_image.verify() # Check image integrity
126
+
127
+ # Calibrate the camera
128
+ camera_params = get_camera_parameters(temp_image_path, validated_lines)
129
+
130
+ # Format the response
131
+ response = CalibrationResponse(
132
+ status="success",
133
+ camera_parameters=camera_params,
134
+ input_lines=validated_lines,
135
+ message="Calibration réussie"
136
+ )
137
+
138
+ return response
139
+
140
+ except Exception as e:
141
+ raise HTTPException(
142
+ status_code=500,
143
+ detail=f"Erreur lors de la calibration: {str(e)}"
144
+ )
145
+
146
+ finally:
147
+ # Clean up the temporary file
148
+ if os.path.exists(temp_image_path):
149
+ os.unlink(temp_image_path)
150
+
151
+ except HTTPException:
152
+ raise
153
+ except Exception as e:
154
+ raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")
155
+
156
+ # if __name__ == "__main__":
157
+ # import uvicorn
158
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
159
+
160
+ # Add this instead:
161
+ # Entry point for Vercel
162
+ app_instance = app
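A quick way to smoke-test these endpoints without deploying is FastAPI's in-process test client. This is a sketch, not part of the commit; it assumes the project's FastAPI/Starlette test dependencies are installed:

```python
from fastapi.testclient import TestClient

from api import app

client = TestClient(app)

# The health check should respond without loading any model weights.
resp = client.get("/health")
assert resp.status_code == 200
print(resp.json())  # {"status": "healthy", "message": "API is running"}
```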
config.yaml ADDED
@@ -0,0 +1,33 @@
1
+ # config.yaml
2
+ # weights_kp: "models/SV_FT_TSWC_kp"
3
+ # weights_line: "models/SV_FT_TSWC_lines"
4
+ # pnl_refine: true
5
+ # device: "cuda:0"
6
+ # input_path: "examples/input/"
7
+ # output_path: "examples/output/"
8
+ # name: "FootDrone.jpg"
9
+ # input_type: "image"
10
+ # display: true
11
+
12
+ # # Start with low thresholds
13
+ # thresholds:
14
+ # - kp_threshold: 0.15
15
+ # line_threshold: 0.15
16
+
17
+
18
+ weights_kp: "models/SV_FT_TSWC_kp"
19
+ weights_line: "models/SV_FT_TSWC_lines"
20
+ device: "cuda:0"
21
+ input_path: "examples/input/"
22
+ output_path: "examples/output/"
23
+ name: "FootDrone.mp4"
24
+ input_type: "video"
25
+ display: true
26
+ pnl_refine: true
27
+ frame_step: 5 # Process one frame out of every 5
28
+ two_pass: true
29
+
30
+ # Start with low thresholds
31
+ thresholds:
32
+ - kp_threshold: 0.15
33
+ line_threshold: 0.15
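A sketch of how such a config can be read (assumes PyYAML, and that the two threshold keys sit in the same list item, as the original indentation suggests):

```python
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["name"], cfg["input_type"], cfg["frame_step"])  # FootDrone.mp4 video 5

# "thresholds" parses as a one-element list holding both thresholds.
thr = cfg["thresholds"][0]
print(thr["kp_threshold"], thr["line_threshold"])  # 0.15 0.15
```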
config/field_config.py ADDED
@@ -0,0 +1,164 @@
1
+ """
2
+ Configuration file for football field elements (lines and keypoints)
3
+ All measurements are in meters
4
+ Field dimensions: 105m x 68m
5
+ Origin (0,0) is at top-left corner
6
+ """
7
+
8
+ # Lines configuration
9
+ LINES = {
10
+ 1: {"name": "Big rect. left bottom", "description": "Surface de réparation gauche - ligne basse"},
11
+ 2: {"name": "Big rect. left main", "description": "Surface de réparation gauche - ligne parallèle"},
12
+ 3: {"name": "Big rect. left top", "description": "Surface de réparation gauche - ligne haute"},
13
+ 4: {"name": "Big rect. right bottom", "description": "Surface de réparation droite - ligne basse"},
14
+ 5: {"name": "Big rect. right main", "description": "Surface de réparation droite - ligne parallèle"},
15
+ 6: {"name": "Big rect. right top", "description": "Surface de réparation droite - ligne haute"},
16
+ 7: {"name": "Goal left crossbar", "description": "Barre transversale but gauche"},
17
+ 8: {"name": "Goal left post left", "description": "Poteau gauche but gauche"},
18
+ 9: {"name": "Goal left post right", "description": "Poteau droit but gauche"},
19
+ 10: {"name": "Goal right crossbar", "description": "Barre transversale but droit"},
20
+ 11: {"name": "Goal right post left", "description": "Poteau gauche but droit"},
21
+ 12: {"name": "Goal right post right", "description": "Poteau droit but droit"},
22
+ 13: {"name": "Middle line", "description": "Ligne médiane"},
23
+ 14: {"name": "Side line bottom", "description": "Ligne de but"},
24
+ 15: {"name": "Side line left", "description": "Ligne de touche gauche"},
25
+ 16: {"name": "Side line right", "description": "Ligne de touche droite"},
26
+ 17: {"name": "Side line top", "description": "Ligne de but opposée"},
27
+ 18: {"name": "Small rect. left bottom", "description": "Petite surface gauche - ligne basse"},
28
+ 19: {"name": "Small rect. left main", "description": "Petite surface gauche - ligne parallèle"},
29
+ 20: {"name": "Small rect. left top", "description": "Petite surface gauche - ligne haute"},
30
+ 21: {"name": "Small rect. right bottom", "description": "Petite surface droite - ligne basse"},
31
+ 22: {"name": "Small rect. right main", "description": "Petite surface droite - ligne parallèle"},
32
+ 23: {"name": "Small rect. right top", "description": "Petite surface droite - ligne haute"}
33
+ }
34
+
35
+ # Keypoints configuration
36
+ KEYPOINTS = {
37
+ # Points principaux (1-30)
38
+ 1: {"name": "Top Left Corner", "coords": [0.0, 0.0], "description": "Coin supérieur gauche"},
39
+ 2: {"name": "Top Middle", "coords": [52.5, 0.0], "description": "Milieu ligne du haut"},
40
+ 3: {"name": "Top Right Corner", "coords": [105.0, 0.0], "description": "Coin supérieur droit"},
41
+ 4: {"name": "Left Big Box Top", "coords": [0.0, 16.5], "description": "Surface gauche haut"},
42
+ 5: {"name": "Left Big Box Main", "coords": [16.5, 16.5], "description": "Surface gauche principale"},
43
+ 6: {"name": "Right Big Box Main", "coords": [88.5, 16.5], "description": "Surface droite principale"},
44
+ 7: {"name": "Right Big Box Top", "coords": [105.0, 16.5], "description": "Surface droite haut"},
45
+ 8: {"name": "Left Small Box Top", "coords": [0.0, 5.5], "description": "Petite surface gauche haut"},
46
+ 9: {"name": "Left Small Box Main", "coords": [5.5, 5.5], "description": "Petite surface gauche principale"},
47
+ 10: {"name": "Right Small Box Main", "coords": [99.5, 5.5], "description": "Petite surface droite principale"},
48
+ 11: {"name": "Right Small Box Top", "coords": [105.0, 5.5], "description": "Petite surface droite haut"},
49
+ 12: {"name": "Left Goal Crossbar Right", "coords": [3.66, 0.0], "description": "Barre transversale gauche - droite"},
50
+ 13: {"name": "Left Goal Post Right", "coords": [3.66, 2.44], "description": "Poteau droit but gauche"},
51
+ 14: {"name": "Right Goal Post Left", "coords": [101.34, 2.44], "description": "Poteau gauche but droit"},
52
+ 15: {"name": "Right Goal Crossbar Left", "coords": [101.34, 0.0], "description": "Barre transversale droite - gauche"},
53
+ 16: {"name": "Left Goal Crossbar Left", "coords": [0.0, 0.0], "description": "Barre transversale gauche - gauche"},
54
+ 17: {"name": "Left Goal Post Left", "coords": [0.0, 2.44], "description": "Poteau gauche but gauche"},
55
+ 18: {"name": "Right Goal Post Right", "coords": [105.0, 2.44], "description": "Poteau droit but droit"},
56
+ 19: {"name": "Right Goal Crossbar Right", "coords": [105.0, 0.0], "description": "Barre transversale droite - droite"},
57
+ 20: {"name": "Left Small Box Bottom", "coords": [0.0, 0.0], "description": "Petite surface gauche bas"},
58
+ 21: {"name": "Left Small Box Bottom Main", "coords": [5.5, 0.0], "description": "Petite surface gauche bas principale"},
59
+ 22: {"name": "Right Small Box Bottom Main", "coords": [99.5, 0.0], "description": "Petite surface droite bas principale"},
60
+ 23: {"name": "Right Small Box Bottom", "coords": [105.0, 0.0], "description": "Petite surface droite bas"},
61
+ 24: {"name": "Left Big Box Bottom", "coords": [0.0, 0.0], "description": "Surface gauche bas"},
62
+ 25: {"name": "Bottom Middle", "coords": [52.5, 0.0], "description": "Milieu ligne du bas"},
63
+ 26: {"name": "Right Big Box Bottom", "coords": [105.0, 0.0], "description": "Surface droite bas"},
64
+ 27: {"name": "Left Big Box Bottom Main", "coords": [16.5, 0.0], "description": "Surface gauche bas principale"},
65
+ 28: {"name": "Bottom Middle Top", "coords": [52.5, 68.0], "description": "Milieu ligne opposée"},
66
+ 29: {"name": "Right Big Box Bottom Main", "coords": [88.5, 0.0], "description": "Surface droite bas principale"},
67
+ 30: {"name": "Left Penalty Spot", "coords": [11.0, 34.0], "description": "Point de penalty gauche"},
68
+
69
+ # Points auxiliaires (31-57)
70
+ 31: {"name": "Left Box Aux 1", "coords": [16.5, 20.0], "description": "Point auxiliaire surface gauche 1"},
71
+ 32: {"name": "Center Circle Left", "coords": [43.35, 34.0], "description": "Rond central gauche"},
72
+ 33: {"name": "Center Circle Right", "coords": [61.65, 34.0], "description": "Rond central droit"},
73
+ 34: {"name": "Right Box Aux 1", "coords": [88.5, 20.0], "description": "Point auxiliaire surface droite 1"},
74
+ 35: {"name": "Left Box Aux 2", "coords": [16.5, 48.0], "description": "Point auxiliaire surface gauche 2"},
75
+ 36: {"name": "Right Box Aux 2", "coords": [88.5, 48.0], "description": "Point auxiliaire surface droite 2"},
76
+ 37: {"name": "Center Circle Top Left", "coords": [43.35, 24.85], "description": "Rond central haut gauche"},
77
+ 38: {"name": "Center Circle Top Right", "coords": [61.65, 24.85], "description": "Rond central haut droit"},
78
+ 39: {"name": "Center Circle Bottom Left", "coords": [43.35, 43.15], "description": "Rond central bas gauche"},
79
+ 40: {"name": "Center Circle Bottom Right", "coords": [61.65, 43.15], "description": "Rond central bas droit"},
80
+ 41: {"name": "Center Circle Left Inner", "coords": [46.03, 34.0], "description": "Rond central intérieur gauche"},
81
+ 42: {"name": "Center Circle Right Inner", "coords": [58.97, 34.0], "description": "Rond central intérieur droit"},
82
+ 43: {"name": "Center Circle Top Inner", "coords": [52.5, 27.53], "description": "Rond central intérieur haut"},
83
+ 44: {"name": "Center Circle Bottom Inner", "coords": [52.5, 40.47], "description": "Rond central intérieur bas"},
84
+ 45: {"name": "Left Penalty Arc Left", "coords": [19.99, 32.29], "description": "Arc penalty gauche - point gauche"},
85
+ 46: {"name": "Left Penalty Arc Right", "coords": [19.99, 35.71], "description": "Arc penalty gauche - point droit"},
86
+ 47: {"name": "Right Penalty Arc Left", "coords": [85.01, 32.29], "description": "Arc penalty droit - point gauche"},
87
+ 48: {"name": "Right Penalty Arc Right", "coords": [85.01, 35.71], "description": "Arc penalty droit - point droit"},
88
+ 49: {"name": "Left Penalty Arc Top", "coords": [16.5, 34.0], "description": "Arc penalty gauche - point haut"},
89
+ 50: {"name": "Right Penalty Arc Top", "coords": [88.5, 34.0], "description": "Right penalty arc - top point"},
+ 51: {"name": "Left Penalty Area Center", "coords": [16.5, 34.0], "description": "Center of left penalty area"},
+ 52: {"name": "Right Penalty Area Center", "coords": [88.5, 34.0], "description": "Center of right penalty area"},
+ 53: {"name": "Right Penalty Spot", "coords": [94.0, 34.0], "description": "Right penalty spot"},
+ 54: {"name": "Center Spot", "coords": [52.5, 34.0], "description": "Center spot"},
+ 55: {"name": "Left Box Center", "coords": [16.5, 34.0], "description": "Center of left box"},
+ 56: {"name": "Right Box Center", "coords": [88.5, 34.0], "description": "Center of right box"},
+ 57: {"name": "Center Circle Center", "coords": [52.5, 34.0], "description": "Center of the center circle"}
+ }
+
+ # Field dimensions
+ FIELD_DIMENSIONS = {
+ "length": 105.0, # meters
+ "width": 68.0, # meters
+ "center_circle_radius": 9.15,
+ "penalty_area_length": 16.5,
+ "penalty_area_width": 40.32,
+ "goal_area_length": 5.5,
+ "goal_area_width": 18.32,
+ "penalty_spot_dist": 11.0,
+ "goal_height": 2.44,
+ "goal_width": 7.32
+ }
+
+ # Line intersections that form keypoints
+ KEYPOINT_LINE_PAIRS = [
+ ["Side line top", "Side line left"],
+ ["Side line top", "Middle line"],
+ ["Side line right", "Side line top"],
+ ["Side line left", "Big rect. left top"],
+ ["Big rect. left top", "Big rect. left main"],
+ ["Big rect. right top", "Big rect. right main"],
+ ["Side line right", "Big rect. right top"],
+ ["Side line left", "Small rect. left top"],
+ ["Small rect. left top", "Small rect. left main"],
+ ["Small rect. right top", "Small rect. right main"],
+ ["Side line right", "Small rect. right top"],
+ ["Goal left crossbar", "Goal left post right"],
+ ["Side line left", "Goal left post right"],
+ ["Side line right", "Goal right post left"],
+ ["Goal right crossbar", "Goal right post left"],
+ ["Goal left crossbar", "Goal left post left"],
+ ["Side line left", "Goal left post left"],
+ ["Side line right", "Goal right post right"],
+ ["Goal right crossbar", "Goal right post right"],
+ ["Side line left", "Small rect. left bottom"],
+ ["Small rect. left bottom", "Small rect. left main"],
+ ["Small rect. right bottom", "Small rect. right main"],
+ ["Side line right", "Small rect. right bottom"],
+ ["Side line left", "Big rect. left bottom"],
+ ["Big rect. left bottom", "Big rect. left main"],
+ ["Big rect. right bottom", "Big rect. right main"],
+ ["Side line right", "Big rect. right bottom"],
+ ["Side line left", "Side line bottom"],
+ ["Side line right", "Side line bottom"]
+ ]
+
+ # Auxiliary keypoint pairs
+ KEYPOINT_AUX_PAIRS = [
+ ["Small rect. left main", "Side line top"],
+ ["Big rect. left main", "Side line top"],
+ ["Big rect. right main", "Side line top"],
+ ["Small rect. right main", "Side line top"],
+ ["Small rect. left main", "Big rect. left top"],
+ ["Big rect. right top", "Small rect. right main"],
+ ["Small rect. left top", "Big rect. left main"],
+ ["Small rect. right top", "Big rect. right main"],
+ ["Small rect. left bottom", "Big rect. left main"],
+ ["Small rect. right bottom", "Big rect. right main"],
+ ["Small rect. left main", "Big rect. left bottom"],
+ ["Small rect. right main", "Big rect. right bottom"],
+ ["Small rect. left main", "Side line bottom"],
+ ["Big rect. left main", "Side line bottom"],
+ ["Big rect. right main", "Side line bottom"],
+ ["Small rect. right main", "Side line bottom"]
+ ]
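
The tables above define each field keypoint as the intersection of two named pitch lines, using the 105 x 68 m dimensions from FIELD_DIMENSIONS. A minimal sketch of how such an intersection can be computed in pitch coordinates; the helper name and the endpoint values are illustrative, not part of this repository:

import numpy as np

def line_intersection(p1, p2, q1, q2):
    """Intersect the infinite lines through (p1, p2) and (q1, q2); points are (x, y) in metres."""
    d1 = np.subtract(p2, p1)
    d2 = np.subtract(q2, q1)
    denom = d1[0] * d2[1] - d1[1] * d2[0]
    if abs(denom) < 1e-9:
        return None  # (nearly) parallel lines
    t = ((q1[0] - p1[0]) * d2[1] - (q1[1] - p1[1]) * d2[0]) / denom
    return (p1[0] + t * d1[0], p1[1] + t * d1[1])

# ["Side line top", "Middle line"] should meet at (52.5, 0.0) on a 105 x 68 pitch.
print(line_intersection((0.0, 0.0), (105.0, 0.0), (52.5, 0.0), (52.5, 68.0)))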
config/hrnetv2_w48.yaml ADDED
@@ -0,0 +1,35 @@
+ MODEL:
+   IMAGE_SIZE: [960, 540]
+   NUM_JOINTS: 58
+   PRETRAIN: ''
+   EXTRA:
+     FINAL_CONV_KERNEL: 1
+     STAGE1:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 1
+       BLOCK: BOTTLENECK
+       NUM_BLOCKS: [4]
+       NUM_CHANNELS: [64]
+       FUSE_METHOD: SUM
+     STAGE2:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 2
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4]
+       NUM_CHANNELS: [48, 96]
+       FUSE_METHOD: SUM
+     STAGE3:
+       NUM_MODULES: 4
+       NUM_BRANCHES: 3
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192]
+       FUSE_METHOD: SUM
+     STAGE4:
+       NUM_MODULES: 3
+       NUM_BRANCHES: 4
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192, 384]
+       FUSE_METHOD: SUM
+
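
This config is read by get_cls_net; the companion hrnetv2_w48_l.yaml below is identical except for NUM_JOINTS: 24 (line heatmaps instead of keypoint heatmaps). In both cases inference.py drops the last channel before decoding, which suggests one extra channel beyond the detected classes; that interpretation is an assumption. A minimal sketch of loading a config and reading its fields, as inference.py does:

import yaml

with open("config/hrnetv2_w48.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["MODEL"]["NUM_JOINTS"])                       # 58 heatmap channels
print(cfg["MODEL"]["EXTRA"]["STAGE4"]["NUM_CHANNELS"])  # [48, 96, 192, 384]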
config/hrnetv2_w48_l.yaml ADDED
@@ -0,0 +1,35 @@
+ MODEL:
+   IMAGE_SIZE: [960, 540]
+   NUM_JOINTS: 24
+   PRETRAIN: ''
+   EXTRA:
+     FINAL_CONV_KERNEL: 1
+     STAGE1:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 1
+       BLOCK: BOTTLENECK
+       NUM_BLOCKS: [4]
+       NUM_CHANNELS: [64]
+       FUSE_METHOD: SUM
+     STAGE2:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 2
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4]
+       NUM_CHANNELS: [48, 96]
+       FUSE_METHOD: SUM
+     STAGE3:
+       NUM_MODULES: 4
+       NUM_BRANCHES: 3
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192]
+       FUSE_METHOD: SUM
+     STAGE4:
+       NUM_MODULES: 3
+       NUM_BRANCHES: 4
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192, 384]
+       FUSE_METHOD: SUM
+
data/__init__.py ADDED
@@ -0,0 +1 @@
+ # Empty file marking this directory as a Python package
data/line_data.py ADDED
@@ -0,0 +1,115 @@
+ # Contains cam1_line_dict and cam3_line_dict
2
+
3
+ cam1_line_dict = {
4
+ "Side line top": [
5
+ {"x": 0, "y": 205.35510086780727},
6
+ {"x": 226.69196710942427, "y": 218.52089231566922},
7
+ {"x": 609.2346616065778, "y": 235.73769651671947},
8
+ {"x": 1072.7387729285256, "y": 248.90348796458142},
9
+ {"x": 1642.5047438330162, "y": 257.005513470958},
10
+ {"x": 1918.7855787476271, "y": 257.005513470958}
11
+ ],
12
+ "Big rect. right top": [
13
+ {"x": 1918.6823296329553, "y": 327.02322511812105},
14
+ {"x": 1761.8195845349164, "y": 327.02322511812105}
15
+ ],
16
+ "Big rect. right main": [
17
+ {"x": 1761.8195845349164, "y": 327.02322511812105},
18
+ {"x": 1774.5760335197388, "y": 409.7284234765342},
19
+ {"x": 1790.8194197631253, "y": 526.9377258021154},
20
+ {"x": 1809.8670388107444, "y": 671.5805829449724}
21
+ ],
22
+ "Big rect. right bottom": [
23
+ {"x": 1809.8670388107444, "y": 671.5805829449724},
24
+ {"x": 1917.4300640208282, "y": 677.1868952373312}
25
+ ],
26
+ "Circle right": [
27
+ {"x": 1774.5760335197388, "y": 409.7284234765342},
28
+ {"x": 1709.2165563955555, "y": 410.6005164997901},
29
+ {"x": 1629.0422644565579, "y": 423.6819118486273},
30
+ {"x": 1576.7546827572116, "y": 442.8679583602552},
31
+ {"x": 1568.0400858073206, "y": 459.43772580211566},
32
+ {"x": 1592.4409572670156, "y": 482.11214440676673},
33
+ {"x": 1698.7590400556867, "y": 511.7633071974644},
34
+ {"x": 1790.8194197631253, "y": 526.9377258021154}
35
+ ],
36
+ "Middle line": [
37
+ {"x": 0, "y": 308.65587267083185},
38
+ {"x": 226.69196710942427, "y": 218.52089231566922}
39
+ ],
40
+ "Circle central": [
41
+ {"x": 0, "y": 340.8997444495873},
42
+ {"x": 122.94647588765228, "y": 351.0820197481417},
43
+ {"x": 234.87016428192882, "y": 366.3554326959732},
44
+ {"x": 305.2464228934815, "y": 390.1140750592667},
45
+ {"x": 309.48595654477987, "y": 413.8727174225603},
46
+ {"x": 228.08691043985144, "y": 441.0254515520386},
47
+ {"x": 71.22416534181235, "y": 463.08704803223975},
48
+ {"x": 4.239533651298354, "y": 468.1781856815169}
49
+ ],
50
+ "Side line bottom": [
51
+ {"x": 1.3071895424836595, "y": 814.522819727929},
52
+ {"x": 636.6013071895421, "y": 827.6042150767662},
53
+ {"x": 1286.2745098039209, "y": 852.458866239557},
54
+ {"x": 1918.954248366012, "y": 889.086773216301}
55
+ ]
56
+ }
57
+
58
+ cam3_line_dict = {
59
+ "Big rect. right top": [
60
+ {"x": 1342.8861505076343, "y": 1076.997434976179},
61
+ {"x": 1484.7446330310781, "y": 906.3705391217808}
62
+ ],
63
+ "Big rect. right main": [
64
+ {"x": 1484.7446330310781, "y": 906.3705391217808},
65
+ {"x": 1049.6210183678218, "y": 748.0287797688992},
66
+ {"x": 828.6491513601493, "y": 668.8579000924583},
67
+ {"x": 349.8767728435256, "y": 500.9610345717304},
68
+ {"x": 32.736572890025556, "y": 397.21988189225624}
69
+ ],
70
+ "Big rect. right bottom": [
71
+ {"x": 32.736572890025556, "y": 397.21988189225624},
72
+ {"x": 0.3753980224568448, "y": 407.0286292126068}
73
+ ],
74
+ "Small rect. right top": [
75
+ {"x": 312.24913494809687, "y": 1075.6461846681693},
76
+ {"x": 426.66666666666663, "y": 999.9279904137233}
77
+ ],
78
+ "Small rect. right main": [
79
+ {"x": 426.66666666666663, "y": 999.9279904137233},
80
+ {"x": 0, "y": 769.079837198949}
81
+ ],
82
+ "Circle right": [
83
+ {"x": 828.6491513601493, "y": 668.8579000924583},
84
+ {"x": 821.7759602949911, "y": 612.2830792373484},
85
+ {"x": 782.8739995106773, "y": 564.5621490047902},
86
+ {"x": 722.6387053930304, "y": 529.3993583071158},
87
+ {"x": 623.5014504910696, "y": 503.02726528386006},
88
+ {"x": 494.24654853028534, "y": 492.980753655953},
89
+ {"x": 349.8767728435256, "y": 500.9610345717304}
90
+ ],
91
+ "Side line bottom": [
92
+ {"x": 2.0193824656299317, "y": 266.2605192109321},
93
+ {"x": 399.0443993689428, "y": 186.14824976426013},
94
+ {"x": 645.5533017804819, "y": 132.93313314748357},
95
+ {"x": 1001.1088573360372, "y": 53.39824942655338},
96
+ {"x": 1208.1676808654488, "y": 7.351737798646435}
97
+ ],
98
+ "Middle line": [
99
+ {"x": 645.5533017804819, "y": 132.93313314748357},
100
+ {"x": 1106.0585089650835, "y": 200.22939899146556},
101
+ {"x": 1580.7388158704541, "y": 269.8451725000601},
102
+ {"x": 1917.6527118636336, "y": 318.9857185061268}
103
+ ],
104
+ "Circle central": [
105
+ {"x": 1580.7388158704541, "y": 269.8451725000601},
106
+ {"x": 1580.7388158704541, "y": 269.8451725000601},
107
+ {"x": 1533.8366024891266, "y": 288.8643838246303},
108
+ {"x": 1441.810458698277, "y": 302.46903498742097},
109
+ {"x": 1316.3202626198458, "y": 304.5620582432349},
110
+ {"x": 1219.0653606590615, "y": 292.0039187083512},
111
+ {"x": 1135.4052299401073, "y": 274.2132210339326},
112
+ {"x": 1069.522876998931, "y": 237.5853140571884},
113
+ {"x": 1106.0585089650835, "y": 200.22939899146556}
114
+ ]
115
+ }
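
These annotations are pixel coordinates in the source frames (values run up to roughly 1919 x 1077, so a 1920x1080 resolution is assumed here). Calibration code such as transform_data in get_camera_params.py expects them normalized to [0, 1]; a minimal sketch of that normalization, with the resolution as an explicit assumption:

from data.line_data import cam1_line_dict

IMG_W, IMG_H = 1920, 1080  # assumed resolution of the annotated frames

normalized = {
    name: [{"x": p["x"] / IMG_W, "y": p["y"] / IMG_H} for p in points]
    for name, points in cam1_line_dict.items()
}
print(normalized["Side line top"][0])  # {'x': 0.0, 'y': ~0.19}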
get_camera_params.py ADDED
@@ -0,0 +1,201 @@
1
+ from utils.utils_keypoints import KeypointsDB
2
+ from utils.utils_lines import LineKeypointsDB
3
+ from utils.utils_calib import FramebyFrameCalib
4
+ from utils.utils_heatmap import complete_keypoints
5
+ from PIL import Image
6
+ import torch
7
+ import numpy as np
8
+
9
+ # Données de lignes pour cam3
10
+ cam3_line_dict = {
11
+ "Big rect. right top": [
12
+ {"x": 1342.8861505076343, "y": 1076.997434976179},
13
+ {"x": 1484.7446330310781, "y": 906.3705391217808}
14
+ ],
15
+ "Big rect. right main": [
16
+ {"x": 1484.7446330310781, "y": 906.3705391217808},
17
+ {"x": 1049.6210183678218, "y": 748.0287797688992},
18
+ {"x": 828.6491513601493, "y": 668.8579000924583},
19
+ {"x": 349.8767728435256, "y": 500.9610345717304},
20
+ {"x": 32.736572890025556, "y": 397.21988189225624}
21
+ ],
22
+ "Big rect. right bottom": [
23
+ {"x": 32.736572890025556, "y": 397.21988189225624},
24
+ {"x": 0.3753980224568448, "y": 407.0286292126068}
25
+ ],
26
+ "Small rect. right top": [
27
+ {"x": 312.24913494809687, "y": 1075.6461846681693},
28
+ {"x": 426.66666666666663, "y": 999.9279904137233}
29
+ ],
30
+ "Small rect. right main": [
31
+ {"x": 426.66666666666663, "y": 999.9279904137233},
32
+ {"x": 0, "y": 769.079837198949}
33
+ ],
34
+ "Circle right": [
35
+ {"x": 828.6491513601493, "y": 668.8579000924583},
36
+ {"x": 821.7759602949911, "y": 612.2830792373484},
37
+ {"x": 782.8739995106773, "y": 564.5621490047902},
38
+ {"x": 722.6387053930304, "y": 529.3993583071158},
39
+ {"x": 623.5014504910696, "y": 503.02726528386006},
40
+ {"x": 494.24654853028534, "y": 492.980753655953},
41
+ {"x": 349.8767728435256, "y": 500.9610345717304}
42
+ ],
43
+ "Side line bottom": [
44
+ {"x": 2.0193824656299317, "y": 266.2605192109321},
45
+ {"x": 399.0443993689428, "y": 186.14824976426013},
46
+ {"x": 645.5533017804819, "y": 132.93313314748357},
47
+ {"x": 1001.1088573360372, "y": 53.39824942655338},
48
+ {"x": 1208.1676808654488, "y": 7.351737798646435}
49
+ ],
50
+ "Middle line": [
51
+ {"x": 645.5533017804819, "y": 132.93313314748357},
52
+ {"x": 1106.0585089650835, "y": 200.22939899146556},
53
+ {"x": 1580.7388158704541, "y": 269.8451725000601},
54
+ {"x": 1917.6527118636336, "y": 318.9857185061268}
55
+ ],
56
+ "Circle central": [
57
+ {"x": 1580.7388158704541, "y": 269.8451725000601},
58
+ {"x": 1580.7388158704541, "y": 269.8451725000601},
59
+ {"x": 1533.8366024891266, "y": 288.8643838246303},
60
+ {"x": 1441.810458698277, "y": 302.46903498742097},
61
+ {"x": 1316.3202626198458, "y": 304.5620582432349},
62
+ {"x": 1219.0653606590615, "y": 292.0039187083512},
63
+ {"x": 1135.4052299401073, "y": 274.2132210339326},
64
+ {"x": 1069.522876998931, "y": 237.5853140571884},
65
+ {"x": 1106.0585089650835, "y": 200.22939899146556},
66
+ {"x": 1139.5882364760548, "y": 189.4457791734675},
67
+ {"x": 1224.2941188289963, "y": 177.9341512664908},
68
+ {"x": 1314.2287593518718, "y": 174.79461638276985},
69
+ {"x": 1392.6601319008914, "y": 180.02717452230473},
70
+ {"x": 1465.8627462799764, "y": 190.49229080137454},
71
+ {"x": 1529.6535959531789, "y": 204.09694196416518},
72
+ {"x": 1581.9411776525253, "y": 230.2597326618396},
73
+ {"x": 1580.7388158704541, "y": 269.8451725000601}
74
+ ],
75
+ "Side line left": [
76
+ {"x": 1208.1676808654488, "y": 7.351737798646435},
77
+ {"x": 1401.9652021886754, "y": 20.565213248502545},
78
+ {"x": 1582.3573590514204, "y": 30.37625976013045},
79
+ {"x": 1679.416182580832, "y": 34.300678364781604},
80
+ {"x": 1824.5142217965183, "y": 41.23091697692868},
81
+ {"x": 1918.6318688553417, "y": 42.21202162809147}
82
+ ],
83
+ "Big rect. left bottom": [
84
+ {"x": 1401.9652021886754, "y": 20.565213248502545},
85
+ {"x": 1283.3377512082834, "y": 53.98527744204496}
86
+ ],
87
+ "Big rect. left main": [
88
+ {"x": 1283.3377512082834, "y": 53.98527744204496},
89
+ {"x": 1510.7887316004399, "y": 73.60737046530076},
90
+ {"x": 1808.8279472867146, "y": 94.21056813971936},
91
+ {"x": 1918.6318688553417, "y": 100.0971960466961}
92
+ ],
93
+ "Circle left": [
94
+ {"x": 1510.7887316004399, "y": 73.60737046530076},
95
+ {"x": 1548.0436335612244, "y": 86.36173093041702},
96
+ {"x": 1620.5926531690673, "y": 95.19167279088215},
97
+ {"x": 1681.3769668945574, "y": 97.15388209320773},
98
+ {"x": 1746.0828492474989, "y": 100.0971960466961},
99
+ {"x": 1808.8279472867146, "y": 94.21056813971936}
100
+ ],
101
+ "Small rect. left bottom": [
102
+ {"x": 1550.9848100318127, "y": 42.21202162809147},
103
+ {"x": 1582.3573590514204, "y": 30.37625976013045}
104
+ ],
105
+ "Small rect. left main": [
106
+ {"x": 1550.9848100318127, "y": 42.21202162809147},
107
+ {"x": 1918.418689198772, "y": 60.49417894940041}
108
+ ]
109
+ }
110
+
111
+ def transform_data(line_dict, width, height):
112
+ """
113
+ Transform input line dictionary to normalized coordinates.
114
+ """
115
+ transformed = {}
116
+ for line_name, points in line_dict.items():
117
+ transformed[line_name] = []
118
+ for point in points:
119
+ transformed[line_name].append({
120
+ "x": point["x"] / width,
121
+ "y": point["y"] / height
122
+ })
123
+ return transformed
124
+
125
+ def get_camera_parameters(image_path, line_dict):
126
+ """
127
+ Extract camera parameters from image and line data.
128
+
129
+ Args:
130
+ image_path (str): Path to the image file
131
+ line_dict (dict): Dictionary containing line coordinates
132
+
133
+ Returns:
134
+ dict: Camera parameters
135
+ """
136
+ # Load image
137
+ image = Image.open(image_path)
138
+ image_tensor = torch.FloatTensor(np.array(image)).permute(2, 0, 1)
139
+
140
+ # Get image dimensions
141
+ img_width, img_height = image.size
142
+
143
+ # Transform data using actual image dimensions
144
+ trans_data = transform_data(line_dict, img_width, img_height)
145
+
146
+ # Initialize databases
147
+ kp_db = KeypointsDB(trans_data, image_tensor)
148
+ ln_db = LineKeypointsDB(trans_data, image_tensor)
149
+
150
+ # Get keypoints and lines
151
+ kp_db.get_full_keypoints()
152
+ ln_db.get_lines()
153
+
154
+ kp_dict = kp_db.keypoints_final
155
+ ln_dict = ln_db.lines
156
+
157
+ # Complete keypoints
158
+ kp_dict, ln_dict = complete_keypoints(kp_dict, ln_dict, img_width, img_height)
159
+
160
+ # Initialize calibration
161
+ cam = FramebyFrameCalib(img_width, img_height)
162
+ cam.update(kp_dict, ln_dict)
163
+ cam_params = cam.heuristic_voting(refine_lines=True)
164
+
165
+ return cam_params
166
+
167
+ def main():
168
+ # Path to your image
+ image_path = "examples/input/cam3.jpg"
+
+ # Get the camera parameters
+ camera_params = get_camera_parameters(image_path, cam3_line_dict)
+
+ # Print the parameters
+ print("=== CAMERA PARAMETERS ===")
+ print(f"Position (meters): {camera_params['cam_params']['position_meters']}")
+ print(f"Focal length X: {camera_params['cam_params']['x_focal_length']:.2f}")
+ print(f"Focal length Y: {camera_params['cam_params']['y_focal_length']:.2f}")
+ print(f"Principal point: {camera_params['cam_params']['principal_point']}")
+ print(f"Rotation matrix:")
+ rotation_matrix = np.array(camera_params['cam_params']['rotation_matrix'])
+ print(rotation_matrix)
+
+ # Compute Euler angles
+ euler_angles = np.array([
+ np.arctan2(rotation_matrix[2,1], rotation_matrix[2,2]), # roll
+ np.arctan2(-rotation_matrix[2,0], np.sqrt(rotation_matrix[2,1]**2 + rotation_matrix[2,2]**2)), # pitch
+ np.arctan2(rotation_matrix[1,0], rotation_matrix[0,0]) # yaw
+ ]) * 180 / np.pi
+
+ print(f"Euler angles (degrees):")
+ print(f" Roll: {euler_angles[0]:.1f}°")
+ print(f" Pitch: {euler_angles[1]:.1f}°")
+ print(f" Yaw: {euler_angles[2]:.1f}°")
195
+
196
+ # print(camera_params)
197
+
198
+ return camera_params
199
+
200
+ if __name__ == "__main__":
201
+ main()
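
From the returned cam_params, a 3x4 projection matrix can be assembled in the usual pinhole form P = K R [I | -C]; this mirrors projection_from_cam_params in inference.py (which then projects pitch points expressed relative to the pitch centre). A minimal sketch:

import numpy as np

def build_projection(cam_params):
    K = np.array([[cam_params["x_focal_length"], 0, cam_params["principal_point"][0]],
                  [0, cam_params["y_focal_length"], cam_params["principal_point"][1]],
                  [0, 0, 1]])
    R = np.array(cam_params["rotation_matrix"])
    It = np.eye(4)[:-1]                                     # 3x4 [I | 0]
    It[:, -1] = -np.array(cam_params["position_meters"])    # [I | -C]
    return K @ (R @ It)

# Usage sketch (assumes heuristic_voting succeeded):
# params = get_camera_parameters("examples/input/cam3.jpg", cam3_line_dict)
# P = build_projection(params["cam_params"])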
inference.py ADDED
@@ -0,0 +1,286 @@
1
+ import cv2
2
+ import yaml
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import torchvision.transforms as T
8
+ import torchvision.transforms.functional as f
9
+
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ from matplotlib.patches import Polygon
13
+
14
+ from model.cls_hrnet import get_cls_net
15
+ from model.cls_hrnet_l import get_cls_net as get_cls_net_l
16
+
17
+ from utils.utils_calib import FramebyFrameCalib, pan_tilt_roll_to_orientation
18
+ from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
19
+ complete_keypoints, coords_to_dict
20
+
21
+
22
+ lines_coords = [[[0., 54.16, 0.], [16.5, 54.16, 0.]],
23
+ [[16.5, 13.84, 0.], [16.5, 54.16, 0.]],
24
+ [[16.5, 13.84, 0.], [0., 13.84, 0.]],
25
+ [[88.5, 54.16, 0.], [105., 54.16, 0.]],
26
+ [[88.5, 13.84, 0.], [88.5, 54.16, 0.]],
27
+ [[88.5, 13.84, 0.], [105., 13.84, 0.]],
28
+ [[0., 37.66, -2.44], [0., 30.34, -2.44]],
29
+ [[0., 37.66, 0.], [0., 37.66, -2.44]],
30
+ [[0., 30.34, 0.], [0., 30.34, -2.44]],
31
+ [[105., 37.66, -2.44], [105., 30.34, -2.44]],
32
+ [[105., 30.34, 0.], [105., 30.34, -2.44]],
33
+ [[105., 37.66, 0.], [105., 37.66, -2.44]],
34
+ [[52.5, 0., 0.], [52.5, 68, 0.]],
35
+ [[0., 68., 0.], [105., 68., 0.]],
36
+ [[0., 0., 0.], [0., 68., 0.]],
37
+ [[105., 0., 0.], [105., 68., 0.]],
38
+ [[0., 0., 0.], [105., 0., 0.]],
39
+ [[0., 43.16, 0.], [5.5, 43.16, 0.]],
40
+ [[5.5, 43.16, 0.], [5.5, 24.84, 0.]],
41
+ [[5.5, 24.84, 0.], [0., 24.84, 0.]],
42
+ [[99.5, 43.16, 0.], [105., 43.16, 0.]],
43
+ [[99.5, 43.16, 0.], [99.5, 24.84, 0.]],
44
+ [[99.5, 24.84, 0.], [105., 24.84, 0.]]]
45
+
46
+
47
+ def projection_from_cam_params(final_params_dict):
48
+ cam_params = final_params_dict["cam_params"]
49
+ x_focal_length = cam_params['x_focal_length']
50
+ y_focal_length = cam_params['y_focal_length']
51
+ principal_point = np.array(cam_params['principal_point'])
52
+ position_meters = np.array(cam_params['position_meters'])
53
+ rotation = np.array(cam_params['rotation_matrix'])
54
+
55
+ It = np.eye(4)[:-1]
56
+ It[:, -1] = -position_meters
57
+ Q = np.array([[x_focal_length, 0, principal_point[0]],
58
+ [0, y_focal_length, principal_point[1]],
59
+ [0, 0, 1]])
60
+ P = Q @ (rotation @ It)
61
+
62
+ return P
63
+
64
+
65
+ def inference(cam, frame, model, model_l, kp_threshold, line_threshold, pnl_refine):
66
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
67
+ frame = Image.fromarray(frame)
68
+
69
+ frame = f.to_tensor(frame).float().unsqueeze(0)
70
+ _, _, h_original, w_original = frame.size()
71
+ frame = frame if frame.size()[-1] == 960 else transform2(frame)
72
+ frame = frame.to(device)
73
+ b, c, h, w = frame.size()
74
+
75
+ with torch.no_grad():
76
+ heatmaps = model(frame)
77
+ heatmaps_l = model_l(frame)
78
+
79
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
80
+ # print("\n"*4,"kp_coords: ", kp_coords)
81
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
82
+ print("\n"*4,"line_coords_from_heatmap: ", line_coords)
83
+ kp_dict = coords_to_dict(kp_coords, threshold=kp_threshold)
84
+ lines_dict = coords_to_dict(line_coords, threshold=line_threshold)
85
+
86
+ # print("\n=== AVANT complete_keypoints ===")
87
+ # print("--- KEYPOINTS ---")
88
+ # print(f"Nombre de keypoints: {len(kp_dict[0])}")
89
+ # print("Structure des keypoints:")
90
+ for kp_key, kp_value in kp_dict[0].items():
91
+ print(f"{kp_key}: {kp_value}")
92
+
93
+ # print("\n--- LIGNES ---")
94
+ # print(f"Nombre de lignes: {len(lines_dict[0])}")
95
+ # print("Structure des lignes:")
96
+ for line_key, line_value in lines_dict[0].items():
97
+ print(f"{line_key}: {line_value}")
98
+
99
+ kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
100
+
101
+ # print("\n=== APRÈS complete_keypoints ===")
102
+ # print("--- KEYPOINTS ---")
103
+ # print(f"Nombre de keypoints: {len(kp_dict)}")
104
+ # print("Structure des keypoints:")
105
+ for kp_key, kp_value in kp_dict.items():
106
+ print(f"{kp_key}: {kp_value}")
107
+
108
+ # print("\n--- LIGNES ---")
109
+ # print(f"Nombre de lignes: {len(lines_dict)}")
110
+ # print("Structure des lignes:")
111
+ for line_key, line_value in lines_dict.items():
112
+ print(f"{line_key}: {line_value}")
113
+
114
+ cam.update(kp_dict, lines_dict)
115
+ final_params_dict = cam.heuristic_voting(refine_lines=pnl_refine)
116
+
117
+ return final_params_dict
118
+
119
+
120
+ def project(frame, P):
121
+
122
+ for line in lines_coords:
123
+ w1 = line[0]
124
+ w2 = line[1]
125
+ i1 = P @ np.array([w1[0]-105/2, w1[1]-68/2, w1[2], 1])
126
+ i2 = P @ np.array([w2[0]-105/2, w2[1]-68/2, w2[2], 1])
127
+ i1 /= i1[-1]
128
+ i2 /= i2[-1]
129
+ frame = cv2.line(frame, (int(i1[0]), int(i1[1])), (int(i2[0]), int(i2[1])), (255, 0, 0), 3)
130
+
131
+ r = 9.15
132
+ pts1, pts2, pts3 = [], [], []
133
+ base_pos = np.array([11-105/2, 68/2-68/2, 0., 0.])
134
+ for ang in np.linspace(37, 143, 50):
135
+ ang = np.deg2rad(ang)
136
+ pos = base_pos + np.array([r*np.sin(ang), r*np.cos(ang), 0., 1.])
137
+ ipos = P @ pos
138
+ ipos /= ipos[-1]
139
+ pts1.append([ipos[0], ipos[1]])
140
+
141
+ base_pos = np.array([94-105/2, 68/2-68/2, 0., 0.])
142
+ for ang in np.linspace(217, 323, 200):
143
+ ang = np.deg2rad(ang)
144
+ pos = base_pos + np.array([r*np.sin(ang), r*np.cos(ang), 0., 1.])
145
+ ipos = P @ pos
146
+ ipos /= ipos[-1]
147
+ pts2.append([ipos[0], ipos[1]])
148
+
149
+ base_pos = np.array([0, 0, 0., 0.])
150
+ for ang in np.linspace(0, 360, 500):
151
+ ang = np.deg2rad(ang)
152
+ pos = base_pos + np.array([r*np.sin(ang), r*np.cos(ang), 0., 1.])
153
+ ipos = P @ pos
154
+ ipos /= ipos[-1]
155
+ pts3.append([ipos[0], ipos[1]])
156
+
157
+ XEllipse1 = np.array(pts1, np.int32)
158
+ XEllipse2 = np.array(pts2, np.int32)
159
+ XEllipse3 = np.array(pts3, np.int32)
160
+ frame = cv2.polylines(frame, [XEllipse1], False, (255, 0, 0), 3)
161
+ frame = cv2.polylines(frame, [XEllipse2], False, (255, 0, 0), 3)
162
+ frame = cv2.polylines(frame, [XEllipse3], False, (255, 0, 0), 3)
163
+
164
+ return frame
165
+
166
+
167
+ def process_input(input_path, input_type, model, model_l, kp_threshold, line_threshold, pnl_refine,
+ save_path, display):
169
+
170
+ cap = cv2.VideoCapture(input_path)
171
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
172
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
173
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
174
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
175
+
176
+ cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)
177
+
178
+ if input_type == 'video':
179
+ cap = cv2.VideoCapture(input_path)
180
+ if save_path != "":
181
+ out = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
182
+
183
+ pbar = tqdm(total=total_frames)
184
+
185
+ while cap.isOpened():
186
+ ret, frame = cap.read()
187
+ if not ret:
188
+ break
189
+
190
+ final_params_dict = inference(cam, frame, model, model_l, kp_threshold, line_threshold, pnl_refine)
191
+ if final_params_dict is not None:
192
+ P = projection_from_cam_params(final_params_dict)
193
+ projected_frame = project(frame, P)
194
+ else:
195
+ projected_frame = frame
196
+
197
+ if save_path != "":
198
+ out.write(projected_frame)
199
+
200
+ if display:
201
+ cv2.imshow('Projected Frame', projected_frame)
202
+ if cv2.waitKey(1) & 0xFF == ord('q'):
203
+ break
204
+
205
+ pbar.update(1)
206
+
207
+ cap.release()
208
+ if save_path != "":
209
+ out.release()
210
+ cv2.destroyAllWindows()
211
+
212
+ elif input_type == 'image':
213
+ frame = cv2.imread(input_path)
214
+ if frame is None:
215
+ print(f"Error: Unable to read the image {input_path}")
216
+ return
217
+
218
+ final_params_dict = inference(cam, frame, model, model_l, kp_threshold, line_threshold, pnl_refine)
219
+
220
+ print("\n"*4,"final_params_dict: ", final_params_dict)
221
+
222
+ if final_params_dict is not None:
223
+ P = projection_from_cam_params(final_params_dict)
224
+ projected_frame = project(frame, P)
225
+ else:
226
+ projected_frame = frame
227
+
228
+ if save_path != "":
229
+ cv2.imwrite(save_path, projected_frame)
230
+ else:
231
+ plt.imshow(cv2.cvtColor(projected_frame, cv2.COLOR_BGR2RGB))
232
+ plt.axis('off')
233
+ plt.show()
234
+
235
+ if __name__ == "__main__":
236
+
237
+ parser = argparse.ArgumentParser(description="Process video or image and plot lines on each frame.")
238
+ parser.add_argument("--weights_kp", type=str, help="Path to the model for keypoint inference.")
239
+ parser.add_argument("--weights_line", type=str, help="Path to the model for line projection.")
240
+ parser.add_argument("--kp_threshold", type=float, default=0.3434, help="Threshold for keypoint detection.")
241
+ parser.add_argument("--line_threshold", type=float, default=0.7867, help="Threshold for line detection.")
242
+ parser.add_argument("--pnl_refine", action="store_true", help="Enable PnL refinement module.")
243
+ parser.add_argument("--device", type=str, default="cuda:0", help="CPU or CUDA device index")
244
+ parser.add_argument("--input_path", type=str, required=True, help="Path to the input video or image file.")
245
+ parser.add_argument("--input_type", type=str, choices=['video', 'image'], required=True,
246
+ help="Type of input: 'video' or 'image'.")
247
+ parser.add_argument("--save_path", type=str, default="", help="Path to save the processed video.")
248
+ parser.add_argument("--display", action="store_true", help="Enable real-time display.")
249
+ args = parser.parse_args()
250
+
251
+
252
+ input_path = args.input_path
253
+ input_type = args.input_type
254
+ model_kp = args.weights_kp
255
+ model_line = args.weights_line
256
+ pnl_refine = args.pnl_refine
257
+ save_path = args.save_path
258
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
259
+ display = args.display and input_type == 'video'
260
+ kp_threshold = args.kp_threshold
261
+ line_threshold = args.line_threshold
262
+
263
+ cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
264
+ cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))
265
+
266
+ loaded_state = torch.load(args.weights_kp, map_location=device, weights_only=True)
267
+ model = get_cls_net(cfg)
268
+ model.load_state_dict(loaded_state)
269
+ model.to(device)
270
+ model.eval()
271
+
272
+ loaded_state_l = torch.load(args.weights_line, map_location=device, weights_only=True)
273
+ model_l = get_cls_net_l(cfg_l)
274
+ model_l.load_state_dict(loaded_state_l)
275
+ model_l.to(device)
276
+ model_l.eval()
277
+
278
+ transform2 = T.Resize((540, 960))
279
+
280
+ process_input(input_path, input_type, model, model_l, kp_threshold, line_threshold, pnl_refine,
+ save_path, display)
282
+
283
+
284
+ # python inference.py --weights_kp "models/SV_FT_TSWC_kp" --weights_line "models/SV_FT_TSWC_lines" --input_path "examples/input/FootDrone.mp4" --input_type "video" --save_path "examples/output/video.mp4" --kp_threshold 0.15 --line_threshold 0.15 --pnl_refine
285
+ # python inference.py --weights_kp "SV_FT_TSWC_kp" --weights_line "SV_FT_TSWC_lines" --input_path "examples/input/FootDrone.jpg" --input_type "image" --save_path "examples/output/FootDrone_inf.jpg" --kp_threshold 0.15 --line_threshold 0.15
286
+ # python inference.py --weights_kp "models/SV_FT_TSWC_kp" --weights_line "models/SV_FT_TSWC_lines" --input_path "examples/input/fisheye_messi.png" --input_type "image" --save_path "examples/output/fisheye_messi_inf.png" --kp_threshold 0.15 --line_threshold 0.15
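
Given the matrix P returned by projection_from_cam_params, a single pitch point is mapped to pixels by homogeneous projection; project() above shifts world coordinates so that the pitch centre is the origin, so the same shift is applied here. A small illustrative helper, not part of the script:

import numpy as np

def project_point(P, x, y, z=0.0):
    """Project a pitch point given in metres with the origin at a pitch corner (105 x 68)."""
    w = np.array([x - 105 / 2, y - 68 / 2, z, 1.0])  # re-centre as project() does
    p = P @ w
    return p[0] / p[2], p[1] / p[2]

# e.g. pixel position of the centre spot / right penalty spot:
# u, v = project_point(P, 52.5, 34.0)
# u, v = project_point(P, 94.0, 34.0)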
model/cls_hrnet.py ADDED
@@ -0,0 +1,479 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os
+ import sys
+ import logging
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ BatchNorm2d = nn.BatchNorm2d
14
+ BN_MOMENTUM = 0.1
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def conv3x3(in_planes, out_planes, stride=1):
19
+ """3x3 convolution with padding"""
20
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
21
+ stride=stride, padding=1, bias=False)
22
+
23
+
24
+ class BasicBlock(nn.Module):
25
+ expansion = 1
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ residual = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ residual = self.downsample(x)
49
+
50
+ out += residual
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class Bottleneck(nn.Module):
57
+ expansion = 4
58
+
59
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
60
+ super(Bottleneck, self).__init__()
61
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
63
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64
+ padding=1, bias=False)
65
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
66
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
67
+ bias=False)
68
+ self.bn3 = BatchNorm2d(planes * self.expansion,
69
+ momentum=BN_MOMENTUM)
70
+ self.relu = nn.ReLU(inplace=True)
71
+ self.downsample = downsample
72
+ self.stride = stride
73
+
74
+ def forward(self, x):
75
+ residual = x
76
+
77
+ out = self.conv1(x)
78
+ out = self.bn1(out)
79
+ out = self.relu(out)
80
+
81
+ out = self.conv2(out)
82
+ out = self.bn2(out)
83
+ out = self.relu(out)
84
+
85
+ out = self.conv3(out)
86
+ out = self.bn3(out)
87
+
88
+ if self.downsample is not None:
89
+ residual = self.downsample(x)
90
+
91
+ out += residual
92
+ out = self.relu(out)
93
+
94
+ return out
95
+
96
+
97
+ class HighResolutionModule(nn.Module):
98
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
99
+ num_channels, fuse_method, multi_scale_output=True):
100
+ super(HighResolutionModule, self).__init__()
101
+ self._check_branches(
102
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
103
+
104
+ self.num_inchannels = num_inchannels
105
+ self.fuse_method = fuse_method
106
+ self.num_branches = num_branches
107
+
108
+ self.multi_scale_output = multi_scale_output
109
+
110
+ self.branches = self._make_branches(
111
+ num_branches, blocks, num_blocks, num_channels)
112
+ self.fuse_layers = self._make_fuse_layers()
113
+ self.relu = nn.ReLU(inplace=True)
114
+
115
+ def _check_branches(self, num_branches, blocks, num_blocks,
116
+ num_inchannels, num_channels):
117
+ if num_branches != len(num_blocks):
118
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
119
+ num_branches, len(num_blocks))
120
+ logger.error(error_msg)
121
+ raise ValueError(error_msg)
122
+
123
+ if num_branches != len(num_channels):
124
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
125
+ num_branches, len(num_channels))
126
+ logger.error(error_msg)
127
+ raise ValueError(error_msg)
128
+
129
+ if num_branches != len(num_inchannels):
130
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
131
+ num_branches, len(num_inchannels))
132
+ logger.error(error_msg)
133
+ raise ValueError(error_msg)
134
+
135
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
136
+ stride=1):
137
+ downsample = None
138
+ if stride != 1 or \
139
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
140
+ downsample = nn.Sequential(
141
+ nn.Conv2d(self.num_inchannels[branch_index],
142
+ num_channels[branch_index] * block.expansion,
143
+ kernel_size=1, stride=stride, bias=False),
144
+ BatchNorm2d(num_channels[branch_index] * block.expansion,
145
+ momentum=BN_MOMENTUM),
146
+ )
147
+
148
+ layers = []
149
+ layers.append(block(self.num_inchannels[branch_index],
150
+ num_channels[branch_index], stride, downsample))
151
+ self.num_inchannels[branch_index] = \
152
+ num_channels[branch_index] * block.expansion
153
+ for i in range(1, num_blocks[branch_index]):
154
+ layers.append(block(self.num_inchannels[branch_index],
155
+ num_channels[branch_index]))
156
+
157
+ return nn.Sequential(*layers)
158
+
159
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
160
+ branches = []
161
+
162
+ for i in range(num_branches):
163
+ branches.append(
164
+ self._make_one_branch(i, block, num_blocks, num_channels))
165
+
166
+ return nn.ModuleList(branches)
167
+
168
+ def _make_fuse_layers(self):
169
+ if self.num_branches == 1:
170
+ return None
171
+
172
+ num_branches = self.num_branches
173
+ num_inchannels = self.num_inchannels
174
+ fuse_layers = []
175
+ for i in range(num_branches if self.multi_scale_output else 1):
176
+ fuse_layer = []
177
+ for j in range(num_branches):
178
+ if j > i:
179
+ fuse_layer.append(nn.Sequential(
180
+ nn.Conv2d(num_inchannels[j],
181
+ num_inchannels[i],
182
+ 1,
183
+ 1,
184
+ 0,
185
+ bias=False),
186
+ BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
187
+ # nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
188
+ elif j == i:
189
+ fuse_layer.append(None)
190
+ else:
191
+ conv3x3s = []
192
+ for k in range(i - j):
193
+ if k == i - j - 1:
194
+ num_outchannels_conv3x3 = num_inchannels[i]
195
+ conv3x3s.append(nn.Sequential(
196
+ nn.Conv2d(num_inchannels[j],
197
+ num_outchannels_conv3x3,
198
+ 3, 2, 1, bias=False),
199
+ BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
200
+ else:
201
+ num_outchannels_conv3x3 = num_inchannels[j]
202
+ conv3x3s.append(nn.Sequential(
203
+ nn.Conv2d(num_inchannels[j],
204
+ num_outchannels_conv3x3,
205
+ 3, 2, 1, bias=False),
206
+ BatchNorm2d(num_outchannels_conv3x3,
207
+ momentum=BN_MOMENTUM),
208
+ nn.ReLU(inplace=True)))
209
+ fuse_layer.append(nn.Sequential(*conv3x3s))
210
+ fuse_layers.append(nn.ModuleList(fuse_layer))
211
+
212
+ return nn.ModuleList(fuse_layers)
213
+
214
+ def get_num_inchannels(self):
215
+ return self.num_inchannels
216
+
217
+ def forward(self, x):
218
+ if self.num_branches == 1:
219
+ return [self.branches[0](x[0])]
220
+
221
+ for i in range(self.num_branches):
222
+ x[i] = self.branches[i](x[i])
223
+
224
+ x_fuse = []
225
+ for i in range(len(self.fuse_layers)):
226
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
227
+ for j in range(1, self.num_branches):
228
+ if i == j:
229
+ y = y + x[j]
230
+ elif j > i:
231
+ y = y + F.interpolate(
232
+ self.fuse_layers[i][j](x[j]),
233
+ size=[x[i].shape[2], x[i].shape[3]],
234
+ mode='bilinear')
235
+ else:
236
+ y = y + self.fuse_layers[i][j](x[j])
237
+ x_fuse.append(self.relu(y))
238
+
239
+ return x_fuse
240
+
241
+
242
+ blocks_dict = {
243
+ 'BASIC': BasicBlock,
244
+ 'BOTTLENECK': Bottleneck
245
+ }
246
+
247
+
248
+ class HighResolutionNet(nn.Module):
249
+
250
+ def __init__(self, config, **kwargs):
251
+ self.inplanes = 64
252
+ extra = config['MODEL']['EXTRA']
253
+ super(HighResolutionNet, self).__init__()
254
+
255
+ # stem net
256
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1,
257
+ bias=False)
258
+ self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
259
+ self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1,
260
+ bias=False)
261
+ self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
262
+ self.relu = nn.ReLU(inplace=True)
263
+ self.sf = nn.Softmax(dim=1)
264
+ self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
265
+
266
+ self.stage2_cfg = extra['STAGE2']
267
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
268
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
269
+ num_channels = [
270
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
271
+ self.transition1 = self._make_transition_layer(
272
+ [256], num_channels)
273
+ self.stage2, pre_stage_channels = self._make_stage(
274
+ self.stage2_cfg, num_channels)
275
+
276
+ self.stage3_cfg = extra['STAGE3']
277
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
278
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
279
+ num_channels = [
280
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
281
+ self.transition2 = self._make_transition_layer(
282
+ pre_stage_channels, num_channels)
283
+ self.stage3, pre_stage_channels = self._make_stage(
284
+ self.stage3_cfg, num_channels)
285
+
286
+ self.stage4_cfg = extra['STAGE4']
287
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
288
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
289
+ num_channels = [
290
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
291
+ self.transition3 = self._make_transition_layer(
292
+ pre_stage_channels, num_channels)
293
+ self.stage4, pre_stage_channels = self._make_stage(
294
+ self.stage4_cfg, num_channels, multi_scale_output=True)
295
+
296
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
297
+ final_inp_channels = sum(pre_stage_channels) + self.inplanes
298
+
299
+ self.head = nn.Sequential(nn.Sequential(
300
+ nn.Conv2d(
301
+ in_channels=final_inp_channels,
302
+ out_channels=final_inp_channels,
303
+ kernel_size=1),
304
+ BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
305
+ nn.ReLU(inplace=True),
306
+ nn.Conv2d(
307
+ in_channels=final_inp_channels,
308
+ out_channels=config['MODEL']['NUM_JOINTS'],
309
+ kernel_size=extra['FINAL_CONV_KERNEL']),
310
+ nn.Softmax(dim=1)))
311
+
312
+
313
+
314
+ def _make_head(self, x, x_skip):
315
+ x = self.upsample(x)
316
+ x = torch.cat([x, x_skip], dim=1)
317
+ x = self.head(x)
318
+
319
+ return x
320
+
321
+ def _make_transition_layer(
322
+ self, num_channels_pre_layer, num_channels_cur_layer):
323
+ num_branches_cur = len(num_channels_cur_layer)
324
+ num_branches_pre = len(num_channels_pre_layer)
325
+
326
+ transition_layers = []
327
+ for i in range(num_branches_cur):
328
+ if i < num_branches_pre:
329
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
330
+ transition_layers.append(nn.Sequential(
331
+ nn.Conv2d(num_channels_pre_layer[i],
332
+ num_channels_cur_layer[i],
333
+ 3,
334
+ 1,
335
+ 1,
336
+ bias=False),
337
+ BatchNorm2d(
338
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
339
+ nn.ReLU(inplace=True)))
340
+ else:
341
+ transition_layers.append(None)
342
+ else:
343
+ conv3x3s = []
344
+ for j in range(i + 1 - num_branches_pre):
345
+ inchannels = num_channels_pre_layer[-1]
346
+ outchannels = num_channels_cur_layer[i] \
347
+ if j == i - num_branches_pre else inchannels
348
+ conv3x3s.append(nn.Sequential(
349
+ nn.Conv2d(
350
+ inchannels, outchannels, 3, 2, 1, bias=False),
351
+ BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
352
+ nn.ReLU(inplace=True)))
353
+ transition_layers.append(nn.Sequential(*conv3x3s))
354
+
355
+ return nn.ModuleList(transition_layers)
356
+
357
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
358
+ downsample = None
359
+ if stride != 1 or inplanes != planes * block.expansion:
360
+ downsample = nn.Sequential(
361
+ nn.Conv2d(inplanes, planes * block.expansion,
362
+ kernel_size=1, stride=stride, bias=False),
363
+ BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
364
+ )
365
+
366
+ layers = []
367
+ layers.append(block(inplanes, planes, stride, downsample))
368
+ inplanes = planes * block.expansion
369
+ for i in range(1, blocks):
370
+ layers.append(block(inplanes, planes))
371
+
372
+ return nn.Sequential(*layers)
373
+
374
+ def _make_stage(self, layer_config, num_inchannels,
375
+ multi_scale_output=True):
376
+ num_modules = layer_config['NUM_MODULES']
377
+ num_branches = layer_config['NUM_BRANCHES']
378
+ num_blocks = layer_config['NUM_BLOCKS']
379
+ num_channels = layer_config['NUM_CHANNELS']
380
+ block = blocks_dict[layer_config['BLOCK']]
381
+ fuse_method = layer_config['FUSE_METHOD']
382
+
383
+ modules = []
384
+ for i in range(num_modules):
385
+ # multi_scale_output is only used last module
386
+ if not multi_scale_output and i == num_modules - 1:
387
+ reset_multi_scale_output = False
388
+ else:
389
+ reset_multi_scale_output = True
390
+ modules.append(
391
+ HighResolutionModule(num_branches,
392
+ block,
393
+ num_blocks,
394
+ num_inchannels,
395
+ num_channels,
396
+ fuse_method,
397
+ reset_multi_scale_output)
398
+ )
399
+ num_inchannels = modules[-1].get_num_inchannels()
400
+
401
+ return nn.Sequential(*modules), num_inchannels
402
+
403
+ def forward(self, x):
404
+ # h, w = x.size(2), x.size(3)
405
+ x = self.conv1(x)
406
+ x_skip = x.clone()
407
+ x = self.bn1(x)
408
+ x = self.relu(x)
409
+ x = self.conv2(x)
410
+ x = self.bn2(x)
411
+ x = self.relu(x)
412
+ x = self.layer1(x)
413
+
414
+ x_list = []
415
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
416
+ if self.transition1[i] is not None:
417
+ x_list.append(self.transition1[i](x))
418
+ else:
419
+ x_list.append(x)
420
+ y_list = self.stage2(x_list)
421
+
422
+ x_list = []
423
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
424
+ if self.transition2[i] is not None:
425
+ x_list.append(self.transition2[i](y_list[-1]))
426
+ else:
427
+ x_list.append(y_list[i])
428
+ y_list = self.stage3(x_list)
429
+
430
+ x_list = []
431
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
432
+ if self.transition3[i] is not None:
433
+ x_list.append(self.transition3[i](y_list[-1]))
434
+ else:
435
+ x_list.append(y_list[i])
436
+ x = self.stage4(x_list)
437
+
438
+ # Head Part
439
+ height, width = x[0].size(2), x[0].size(3)
440
+ x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
441
+ x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
442
+ x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
443
+ x = torch.cat([x[0], x1, x2, x3], 1)
444
+ x = self._make_head(x, x_skip)
445
+
446
+ return x
447
+
448
+ def init_weights(self, pretrained=''):
449
+ logger.info('=> init weights from normal distribution')
450
+ for m in self.modules():
451
+ if isinstance(m, nn.Conv2d):
452
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
453
+ #nn.init.normal_(m.weight, std=0.001)
454
+ #nn.init.constant_(m.bias, 0)
455
+ elif isinstance(m, nn.BatchNorm2d):
456
+ nn.init.constant_(m.weight, 1)
457
+ nn.init.constant_(m.bias, 0)
458
+ if pretrained != '':
459
+ if os.path.isfile(pretrained):
460
+ pretrained_dict = torch.load(pretrained)
461
+ logger.info('=> loading pretrained model {}'.format(pretrained))
462
+ print('=> loading pretrained model {}'.format(pretrained))
463
+ model_dict = self.state_dict()
464
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
465
+ if k in model_dict.keys()}
466
+ for k, _ in pretrained_dict.items():
467
+ logger.info(
468
+ '=> loading {} pretrained model {}'.format(k, pretrained))
469
+ #print('=> loading {} pretrained model {}'.format(k, pretrained))
470
+ model_dict.update(pretrained_dict)
471
+ self.load_state_dict(model_dict)
472
+ else:
473
+ sys.exit(f'Weights {pretrained} not found.')
474
+
475
+
476
+ def get_cls_net(config, pretrained='', **kwargs):
477
+ model = HighResolutionNet(config, **kwargs)
478
+ model.init_weights(pretrained)
479
+ return model
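
A quick sanity check of the keypoint network is a dummy forward pass at the configured input size. With two stride-2 stem convolutions and a single 2x upsample in the head, the heatmaps should come out at half the input resolution; the expected shape below is inferred from the code, not a documented guarantee:

import yaml
import torch
from model.cls_hrnet import get_cls_net

cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml"))
model = get_cls_net(cfg).eval()

x = torch.randn(1, 3, 540, 960)   # (B, C, H, W) matching IMAGE_SIZE [960, 540]
with torch.no_grad():
    heatmaps = model(x)
print(heatmaps.shape)             # expected: torch.Size([1, 58, 270, 480])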
model/cls_hrnet_l.py ADDED
@@ -0,0 +1,478 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os
6
+ import logging
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ BatchNorm2d = nn.BatchNorm2d
14
+ BN_MOMENTUM = 0.1
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def conv3x3(in_planes, out_planes, stride=1):
19
+ """3x3 convolution with padding"""
20
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
21
+ stride=stride, padding=1, bias=False)
22
+
23
+
24
+ class BasicBlock(nn.Module):
25
+ expansion = 1
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ residual = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ residual = self.downsample(x)
49
+
50
+ out += residual
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class Bottleneck(nn.Module):
57
+ expansion = 4
58
+
59
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
60
+ super(Bottleneck, self).__init__()
61
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
63
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64
+ padding=1, bias=False)
65
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
66
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
67
+ bias=False)
68
+ self.bn3 = BatchNorm2d(planes * self.expansion,
69
+ momentum=BN_MOMENTUM)
70
+ self.relu = nn.ReLU(inplace=True)
71
+ self.downsample = downsample
72
+ self.stride = stride
73
+
74
+ def forward(self, x):
75
+ residual = x
76
+
77
+ out = self.conv1(x)
78
+ out = self.bn1(out)
79
+ out = self.relu(out)
80
+
81
+ out = self.conv2(out)
82
+ out = self.bn2(out)
83
+ out = self.relu(out)
84
+
85
+ out = self.conv3(out)
86
+ out = self.bn3(out)
87
+
88
+ if self.downsample is not None:
89
+ residual = self.downsample(x)
90
+
91
+ out += residual
92
+ out = self.relu(out)
93
+
94
+ return out
95
+
96
+
97
+ class HighResolutionModule(nn.Module):
98
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
99
+ num_channels, fuse_method, multi_scale_output=True):
100
+ super(HighResolutionModule, self).__init__()
101
+ self._check_branches(
102
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
103
+
104
+ self.num_inchannels = num_inchannels
105
+ self.fuse_method = fuse_method
106
+ self.num_branches = num_branches
107
+
108
+ self.multi_scale_output = multi_scale_output
109
+
110
+ self.branches = self._make_branches(
111
+ num_branches, blocks, num_blocks, num_channels)
112
+ self.fuse_layers = self._make_fuse_layers()
113
+ self.relu = nn.ReLU(inplace=True)
114
+
115
+ def _check_branches(self, num_branches, blocks, num_blocks,
116
+ num_inchannels, num_channels):
117
+ if num_branches != len(num_blocks):
118
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
119
+ num_branches, len(num_blocks))
120
+ logger.error(error_msg)
121
+ raise ValueError(error_msg)
122
+
123
+ if num_branches != len(num_channels):
124
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
125
+ num_branches, len(num_channels))
126
+ logger.error(error_msg)
127
+ raise ValueError(error_msg)
128
+
129
+ if num_branches != len(num_inchannels):
130
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
131
+ num_branches, len(num_inchannels))
132
+ logger.error(error_msg)
133
+ raise ValueError(error_msg)
134
+
135
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
136
+ stride=1):
137
+ downsample = None
138
+ if stride != 1 or \
139
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
140
+ downsample = nn.Sequential(
141
+ nn.Conv2d(self.num_inchannels[branch_index],
142
+ num_channels[branch_index] * block.expansion,
143
+ kernel_size=1, stride=stride, bias=False),
144
+ BatchNorm2d(num_channels[branch_index] * block.expansion,
145
+ momentum=BN_MOMENTUM),
146
+ )
147
+
148
+ layers = []
149
+ layers.append(block(self.num_inchannels[branch_index],
150
+ num_channels[branch_index], stride, downsample))
151
+ self.num_inchannels[branch_index] = \
152
+ num_channels[branch_index] * block.expansion
153
+ for i in range(1, num_blocks[branch_index]):
154
+ layers.append(block(self.num_inchannels[branch_index],
155
+ num_channels[branch_index]))
156
+
157
+ return nn.Sequential(*layers)
158
+
159
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
160
+ branches = []
161
+
162
+ for i in range(num_branches):
163
+ branches.append(
164
+ self._make_one_branch(i, block, num_blocks, num_channels))
165
+
166
+ return nn.ModuleList(branches)
167
+
168
+ def _make_fuse_layers(self):
169
+ if self.num_branches == 1:
170
+ return None
171
+
172
+ num_branches = self.num_branches
173
+ num_inchannels = self.num_inchannels
174
+ fuse_layers = []
175
+ for i in range(num_branches if self.multi_scale_output else 1):
176
+ fuse_layer = []
177
+ for j in range(num_branches):
178
+ if j > i:
179
+ fuse_layer.append(nn.Sequential(
180
+ nn.Conv2d(num_inchannels[j],
181
+ num_inchannels[i],
182
+ 1,
183
+ 1,
184
+ 0,
185
+ bias=False),
186
+ BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
187
+ # nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
188
+ elif j == i:
189
+ fuse_layer.append(None)
190
+ else:
191
+ conv3x3s = []
192
+ for k in range(i - j):
193
+ if k == i - j - 1:
194
+ num_outchannels_conv3x3 = num_inchannels[i]
195
+ conv3x3s.append(nn.Sequential(
196
+ nn.Conv2d(num_inchannels[j],
197
+ num_outchannels_conv3x3,
198
+ 3, 2, 1, bias=False),
199
+ BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
200
+ else:
201
+ num_outchannels_conv3x3 = num_inchannels[j]
202
+ conv3x3s.append(nn.Sequential(
203
+ nn.Conv2d(num_inchannels[j],
204
+ num_outchannels_conv3x3,
205
+ 3, 2, 1, bias=False),
206
+ BatchNorm2d(num_outchannels_conv3x3,
207
+ momentum=BN_MOMENTUM),
208
+ nn.ReLU(inplace=True)))
209
+ fuse_layer.append(nn.Sequential(*conv3x3s))
210
+ fuse_layers.append(nn.ModuleList(fuse_layer))
211
+
212
+ return nn.ModuleList(fuse_layers)
213
+
214
+ def get_num_inchannels(self):
215
+ return self.num_inchannels
216
+
217
+ def forward(self, x):
218
+ if self.num_branches == 1:
219
+ return [self.branches[0](x[0])]
220
+
221
+ for i in range(self.num_branches):
222
+ x[i] = self.branches[i](x[i])
223
+
224
+ x_fuse = []
225
+ for i in range(len(self.fuse_layers)):
226
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
227
+ for j in range(1, self.num_branches):
228
+ if i == j:
229
+ y = y + x[j]
230
+ elif j > i:
231
+ y = y + F.interpolate(
232
+ self.fuse_layers[i][j](x[j]),
233
+ size=[x[i].shape[2], x[i].shape[3]],
234
+ mode='bilinear')
235
+ else:
236
+ y = y + self.fuse_layers[i][j](x[j])
237
+ x_fuse.append(self.relu(y))
238
+
239
+ return x_fuse
240
+
241
+
242
+ blocks_dict = {
243
+ 'BASIC': BasicBlock,
244
+ 'BOTTLENECK': Bottleneck
245
+ }
246
+
247
+
248
+ class HighResolutionNet(nn.Module):
249
+
250
+ def __init__(self, config, **kwargs):
251
+ self.inplanes = 64
252
+ extra = config['MODEL']['EXTRA']
253
+ super(HighResolutionNet, self).__init__()
254
+
255
+ # stem net
256
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1,
257
+ bias=False)
258
+ self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
259
+ self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1,
260
+ bias=False)
261
+ self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
262
+ self.relu = nn.ReLU(inplace=True)
263
+ self.sf = nn.Softmax(dim=1)
264
+ self.layer1 = self._make_layer(Bottleneck, self.inplanes, self.inplanes, 4)
265
+
266
+ self.stage2_cfg = extra['STAGE2']
267
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
268
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
269
+ num_channels = [
270
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
271
+ self.transition1 = self._make_transition_layer(
272
+ [256], num_channels)
273
+ self.stage2, pre_stage_channels = self._make_stage(
274
+ self.stage2_cfg, num_channels)
275
+
276
+ self.stage3_cfg = extra['STAGE3']
277
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
278
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
279
+ num_channels = [
280
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
281
+ self.transition2 = self._make_transition_layer(
282
+ pre_stage_channels, num_channels)
283
+ self.stage3, pre_stage_channels = self._make_stage(
284
+ self.stage3_cfg, num_channels)
285
+
286
+ self.stage4_cfg = extra['STAGE4']
287
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
288
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
289
+ num_channels = [
290
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
291
+ self.transition3 = self._make_transition_layer(
292
+ pre_stage_channels, num_channels)
293
+ self.stage4, pre_stage_channels = self._make_stage(
294
+ self.stage4_cfg, num_channels, multi_scale_output=True)
295
+
296
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
297
+ final_inp_channels = sum(pre_stage_channels) + self.inplanes
298
+
299
+ self.head = nn.Sequential(nn.Sequential(
300
+ nn.Conv2d(
301
+ in_channels=final_inp_channels,
302
+ out_channels=final_inp_channels,
303
+ kernel_size=1),
304
+ BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
305
+ nn.ReLU(inplace=True),
306
+ nn.Conv2d(
307
+ in_channels=final_inp_channels,
308
+ out_channels=config['MODEL']['NUM_JOINTS'],
309
+ kernel_size=extra['FINAL_CONV_KERNEL']),
310
+ nn.Sigmoid()))
311
+
312
+
313
+
314
+
315
+ def _make_head(self, x, x_skip):
316
+ x = self.upsample(x)
317
+ x = torch.cat([x, x_skip], dim=1)
318
+ x = self.head(x)
319
+
320
+ return x
321
+
322
+ def _make_transition_layer(
323
+ self, num_channels_pre_layer, num_channels_cur_layer):
324
+ num_branches_cur = len(num_channels_cur_layer)
325
+ num_branches_pre = len(num_channels_pre_layer)
326
+
327
+ transition_layers = []
328
+ for i in range(num_branches_cur):
329
+ if i < num_branches_pre:
330
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
331
+ transition_layers.append(nn.Sequential(
332
+ nn.Conv2d(num_channels_pre_layer[i],
333
+ num_channels_cur_layer[i],
334
+ 3,
335
+ 1,
336
+ 1,
337
+ bias=False),
338
+ BatchNorm2d(
339
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
340
+ nn.ReLU(inplace=True)))
341
+ else:
342
+ transition_layers.append(None)
343
+ else:
344
+ conv3x3s = []
345
+ for j in range(i + 1 - num_branches_pre):
346
+ inchannels = num_channels_pre_layer[-1]
347
+ outchannels = num_channels_cur_layer[i] \
348
+ if j == i - num_branches_pre else inchannels
349
+ conv3x3s.append(nn.Sequential(
350
+ nn.Conv2d(
351
+ inchannels, outchannels, 3, 2, 1, bias=False),
352
+ BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
353
+ nn.ReLU(inplace=True)))
354
+ transition_layers.append(nn.Sequential(*conv3x3s))
355
+
356
+ return nn.ModuleList(transition_layers)
357
+
358
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
359
+ downsample = None
360
+ if stride != 1 or inplanes != planes * block.expansion:
361
+ downsample = nn.Sequential(
362
+ nn.Conv2d(inplanes, planes * block.expansion,
363
+ kernel_size=1, stride=stride, bias=False),
364
+ BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
365
+ )
366
+
367
+ layers = []
368
+ layers.append(block(inplanes, planes, stride, downsample))
369
+ inplanes = planes * block.expansion
370
+ for i in range(1, blocks):
371
+ layers.append(block(inplanes, planes))
372
+
373
+ return nn.Sequential(*layers)
374
+
375
+ def _make_stage(self, layer_config, num_inchannels,
376
+ multi_scale_output=True):
377
+ num_modules = layer_config['NUM_MODULES']
378
+ num_branches = layer_config['NUM_BRANCHES']
379
+ num_blocks = layer_config['NUM_BLOCKS']
380
+ num_channels = layer_config['NUM_CHANNELS']
381
+ block = blocks_dict[layer_config['BLOCK']]
382
+ fuse_method = layer_config['FUSE_METHOD']
383
+
384
+ modules = []
385
+ for i in range(num_modules):
386
+ # multi_scale_output only takes effect for the last module
387
+ if not multi_scale_output and i == num_modules - 1:
388
+ reset_multi_scale_output = False
389
+ else:
390
+ reset_multi_scale_output = True
391
+ modules.append(
392
+ HighResolutionModule(num_branches,
393
+ block,
394
+ num_blocks,
395
+ num_inchannels,
396
+ num_channels,
397
+ fuse_method,
398
+ reset_multi_scale_output)
399
+ )
400
+ num_inchannels = modules[-1].get_num_inchannels()
401
+
402
+ return nn.Sequential(*modules), num_inchannels
403
+
404
+ def forward(self, x):
405
+ # h, w = x.size(2), x.size(3)
406
+ x = self.conv1(x)
407
+ x_skip = x.clone()
408
+ x = self.bn1(x)
409
+ x = self.relu(x)
410
+ x = self.conv2(x)
411
+ x = self.bn2(x)
412
+ x = self.relu(x)
413
+ x = self.layer1(x)
414
+
415
+ x_list = []
416
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
417
+ if self.transition1[i] is not None:
418
+ x_list.append(self.transition1[i](x))
419
+ else:
420
+ x_list.append(x)
421
+ y_list = self.stage2(x_list)
422
+
423
+ x_list = []
424
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
425
+ if self.transition2[i] is not None:
426
+ x_list.append(self.transition2[i](y_list[-1]))
427
+ else:
428
+ x_list.append(y_list[i])
429
+ y_list = self.stage3(x_list)
430
+
431
+ x_list = []
432
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
433
+ if self.transition3[i] is not None:
434
+ x_list.append(self.transition3[i](y_list[-1]))
435
+ else:
436
+ x_list.append(y_list[i])
437
+ x = self.stage4(x_list)
438
+
439
+ # Head Part
440
+ height, width = x[0].size(2), x[0].size(3)
441
+ x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
442
+ x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
443
+ x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
444
+ x = torch.cat([x[0], x1, x2, x3], 1)
445
+ x = self._make_head(x, x_skip)
446
+
447
+ return x
448
+
449
+ def init_weights(self, pretrained=''):
450
+ logger.info('=> init weights from normal distribution')
451
+ for m in self.modules():
452
+ if isinstance(m, nn.Conv2d):
453
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
454
+ nn.init.normal_(m.weight, std=0.001)
455
+ # nn.init.constant_(m.bias, 0)
456
+ elif isinstance(m, nn.BatchNorm2d):
457
+ nn.init.constant_(m.weight, 1)
458
+ nn.init.constant_(m.bias, 0)
459
+ if os.path.isfile(pretrained):
460
+ pretrained_dict = torch.load(pretrained)
461
+ logger.info('=> loading pretrained model {}'.format(pretrained))
462
+ print('=> loading pretrained model {}'.format(pretrained))
463
+ model_dict = self.state_dict()
464
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
465
+ if k in model_dict.keys()}
466
+ for k, _ in pretrained_dict.items():
467
+ logger.info(
468
+ '=> loading {} pretrained model {}'.format(k, pretrained))
469
+ #print('=> loading {} pretrained model {}'.format(k, pretrained))
470
+ model_dict.update(pretrained_dict)
471
+ self.load_state_dict(model_dict)
472
+
473
+
474
+ def get_cls_net(config, pretrained='', **kwargs):
475
+ model = HighResolutionNet(config, **kwargs)
476
+ model.init_weights(pretrained)
477
+ return model
478
+
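`get_cls_net` is the factory for this file: it builds `HighResolutionNet` from a parsed config, initializes weights, and, when a checkpoint path is given, loads only the keys that match the model's state dict. A minimal usage sketch (not part of the uploaded files); the config path, checkpoint path and input resolution are illustrative assumptions, and the YAML is assumed to expose the `MODEL`/`EXTRA` keys referenced above.

```python
# Usage sketch only: file paths and shapes are illustrative assumptions.
import yaml
import torch

from model.cls_hrnet import get_cls_net

with open("config/hrnetv2_w48.yaml") as f:          # assumed keypoint config
    cfg = yaml.safe_load(f)

model = get_cls_net(cfg)                            # random init via init_weights
# model = get_cls_net(cfg, pretrained="weights/keypoints.pth")  # hypothetical checkpoint

model.eval()
with torch.no_grad():
    frame = torch.randn(1, 3, 540, 960)             # B x 3 x H x W
    heatmaps = model(frame)                         # sigmoid heatmap per keypoint channel
print(heatmaps.shape)
```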
model/dataloader.py ADDED
@@ -0,0 +1,219 @@
1
+ import copy
2
+ import os
3
+ import sys
4
+ import glob
5
+ import json
6
+ import torch
7
+ import numpy as np
8
+ import torchvision.transforms as T
9
+ import torchvision.transforms.functional as f
10
+
11
+ from torchvision.transforms import v2
12
+ from torch.utils.data import Dataset
13
+ from PIL import Image
14
+
15
+ from utils.utils_keypoints import KeypointsDB
16
+ from utils.utils_keypointsWC import KeypointsWCDB
17
+
18
+
19
+ class SoccerNetCalibrationDataset(Dataset):
20
+
21
+ def __init__(self, root_dir, split, transform, main_cam_only=True):
22
+
23
+ self.root_dir = root_dir
24
+ self.split = split
25
+ self.transform = transform
26
+
27
+ self.match_info = json.load(open(root_dir + split + '/match_info.json'))
28
+ self.files = self.get_image_files(rate=1)
29
+
30
+ if main_cam_only:
31
+ self.get_main_camera()
32
+
33
+ def get_image_files(self, rate=3):
34
+ files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
35
+ files.sort()
36
+ if rate > 1:
37
+ files = files[::rate]
38
+ return files
39
+
40
+
41
+ def __len__(self):
42
+ return len(self.files)
43
+
44
+ def __getitem__(self, idx):
45
+ if torch.is_tensor(idx):
46
+ idx = idx.tolist()
47
+
48
+ img_name = self.files[idx]
49
+ image = Image.open(img_name)
50
+ data = json.load(open(img_name.split('.')[0] + ".json"))
51
+ data = self.correct_labels(data)
52
+ sample = self.transform({'image': image, 'data': data})
53
+
54
+ img_db = KeypointsDB(sample['data'], sample['image'])
55
+ target, mask = img_db.get_tensor_w_mask()
56
+ image = sample['image']
57
+
58
+ return image, torch.from_numpy(target).float(), torch.from_numpy(mask).float()
59
+
60
+
61
+ def get_main_camera(self):
62
+ self.files = [file for file in self.files if int(self.match_info[file.split('/')[-1]]['ms_time']) == \
63
+ int(self.match_info[file.split('/')[-1]]['replay_time'])]
64
+
65
+ def correct_labels(self, data):
66
+ if 'Goal left post left' in data.keys():
67
+ data['Goal left post left '] = copy.deepcopy(data['Goal left post left'])
68
+ del data['Goal left post left']
69
+
70
+ return data
71
+
72
+
73
+ class WorldCup2014Dataset(Dataset):
74
+
75
+ def __init__(self, root_dir, split, transform):
76
+ self.root_dir = root_dir
77
+ self.split = split
78
+ self.transform = transform
79
+ assert self.split in ['train_val', 'test'], f'unknown dataset type {self.split}'
80
+
81
+ self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
82
+ self.homographies = glob.glob(os.path.join(self.root_dir + self.split, "*.homographyMatrix"))
83
+ self.num_samples = len(self.files)
84
+
85
+ self.files.sort()
86
+ self.homographies.sort()
87
+
88
+ def __len__(self):
89
+ return self.num_samples
90
+
91
+ def __getitem__(self, idx):
92
+ image = self.get_image_by_index(idx)
93
+ homography = self.get_homography_by_index(idx)
94
+ img_db = KeypointsWCDB(image, homography, (960,540))
95
+ target, mask = img_db.get_tensor_w_mask()
96
+
97
+ sample = self.transform({'image': image, 'target': target, 'mask': mask})
98
+
99
+ return sample['image'], sample['target'], sample['mask']
100
+
101
+ def convert_homography_WC14GT_to_SN(self, H):
102
+ T = np.eye(3)
103
+ #T[0, -1] = -115 / 2
104
+ #T[1, -1] = -74 / 2
105
+ yard2meter = 0.9144
106
+ S = np.eye(3)
107
+ S[0, 0] = yard2meter
108
+ S[1, 1] = yard2meter
109
+ H_SN = S @ (T @ H)
110
+
111
+ return H_SN
112
+
113
+ def get_image_by_index(self, index):
114
+ img_file = self.files[index]
115
+ image = Image.open(img_file)
116
+ return image
117
+
118
+ def get_homography_by_index(self, index):
119
+ homography_file = self.homographies[index]
120
+ with open(homography_file, 'r') as file:
121
+ lines = file.readlines()
122
+ matrix_elements = []
123
+ for line in lines:
124
+ matrix_elements.extend([float(element) for element in line.split()])
125
+ homography = np.array(matrix_elements).reshape((3, 3))
126
+ homography = self.convert_homography_WC14GT_to_SN(homography)
127
+ homography = torch.from_numpy(homography)
128
+ homography = homography / homography[2:3, 2:3]
129
+ return homography
130
+
131
+
132
+ class TSWorldCupDataset(Dataset):
133
+
134
+ def __init__(self, root_dir, split, transform):
135
+ self.root_dir = root_dir
136
+ self.split = split
137
+ self.transform = transform
138
+ assert self.split in ['train', 'test'], f'unknown dataset type {self.split}'
139
+
140
+ self.files_txt = self.get_txt()
141
+
142
+ self.files = self.get_jpg_files()
143
+ self.homographies = self.get_homographies()
144
+ self.num_samples = len(self.files)
145
+
146
+ self.files.sort()
147
+ self.homographies.sort()
148
+
149
+ def __len__(self):
150
+ return self.num_samples
151
+
152
+ def __getitem__(self, idx):
153
+ image = self.get_image_by_index(idx)
154
+ homography = self.get_homography_by_index(idx)
155
+ img_db = KeypointsWCDB(image, homography, (960,540))
156
+ target, mask = img_db.get_tensor_w_mask()
157
+
158
+ sample = self.transform({'image': image, 'target': target, 'mask': mask})
159
+
160
+ return sample['image'], sample['target'], sample['mask']
161
+
162
+
163
+ def get_txt(self):
164
+ with open(self.root_dir + self.split + '.txt', 'r') as file:
165
+ lines = file.readlines()
166
+ lines = [line.strip() for line in lines]
167
+ return lines
168
+
169
+ def get_jpg_files(self):
170
+ all_jpg_files = []
171
+ for dir in self.files_txt:
172
+ full_dir = self.root_dir + "Dataset/80_95/" + dir
173
+ jpg_files = []
174
+ for file in os.listdir(full_dir):
175
+ if file.lower().endswith('.jpg') or file.lower().endswith('.jpeg'):
176
+ jpg_files.append(os.path.join(full_dir, file))
177
+
178
+ all_jpg_files.extend(jpg_files)
179
+
180
+ return all_jpg_files
181
+
182
+ def get_homographies(self):
183
+ all_homographies = []
184
+ for dir in self.files_txt:
185
+ full_dir = self.root_dir + "Annotations/80_95/" + dir
186
+ homographies = []
187
+ for file in os.listdir(full_dir):
188
+ if file.lower().endswith('.npy'):
189
+ homographies.append(os.path.join(full_dir, file))
190
+
191
+ all_homographies.extend(homographies)
192
+
193
+ return all_homographies
194
+
195
+
196
+ def convert_homography_WC14GT_to_SN(self, H):
197
+ T = np.eye(3)
198
+ #T[0, -1] = -115 / 2
199
+ #T[1, -1] = -74 / 2
200
+ yard2meter = 0.9144
201
+ S = np.eye(3)
202
+ S[0, 0] = yard2meter
203
+ S[1, 1] = yard2meter
204
+ H_SN = S @ (T @ H)
205
+
206
+ return H_SN
207
+
208
+ def get_image_by_index(self, index):
209
+ img_file = self.files[index]
210
+ image = Image.open(img_file)
211
+ return image
212
+
213
+ def get_homography_by_index(self, index):
214
+ homography_file = self.homographies[index]
215
+ homography = np.load(homography_file)
216
+ homography = self.convert_homography_WC14GT_to_SN(homography)
217
+ homography = torch.from_numpy(homography)
218
+ homography = homography / homography[2:3, 2:3]
219
+ return homography
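All three dataset classes above yield `(image, target, mask)` heatmap triplets for the keypoint model. For `SoccerNetCalibrationDataset`, `root_dir` is concatenated directly with the split name, so it must end with a slash, and each split folder is expected to hold the frame `.jpg`/`.json` pairs plus the `match_info.json` used by the main-camera filter. A minimal sketch (dataset root, split names and batch size are illustrative):

```python
# Sketch only: paths, split names and batch size are illustrative.
from torch.utils.data import DataLoader

from model.dataloader import SoccerNetCalibrationDataset
from model.transforms import transforms, no_transforms

train_set = SoccerNetCalibrationDataset("datasets/calibration/", "train",
                                        transform=transforms)       # augmented
valid_set = SoccerNetCalibrationDataset("datasets/calibration/", "valid",
                                        transform=no_transforms)    # tensor conversion only

train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
image, target, mask = next(iter(train_loader))   # frames, keypoint heatmaps, per-keypoint mask
print(image.shape, target.shape, mask.shape)
```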
model/dataloader_l.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import sys
3
+ import glob
4
+ import json
5
+ import torch
6
+ import numpy as np
7
+
8
+ from torchvision.transforms import v2
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+
12
+ from utils.utils_lines import LineKeypointsDB
13
+ from utils.utils_linesWC import LineKeypointsWCDB
14
+
15
+
16
+
17
+ class SoccerNetCalibrationDataset(Dataset):
18
+
19
+ def __init__(self, root_dir, split, transform, main_cam_only=True):
20
+
21
+ self.root_dir = root_dir
22
+ self.split = split
23
+ self.transform = transform
24
+
25
+ #self.match_info = json.load(open(root_dir + split + '/match_info.json'))
26
+ self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
27
+
28
+ if main_cam_only:
29
+ self.get_main_camera()
30
+
31
+
32
+ def __len__(self):
33
+ return len(self.files)
34
+
35
+ def __getitem__(self, idx):
36
+ if torch.is_tensor(idx):
37
+ idx = idx.tolist()
38
+
39
+ img_name = self.files[idx]
40
+ image = Image.open(img_name)
41
+ data = json.load(open(img_name.split('.')[0] + ".json"))
42
+ sample = self.transform({'image': image, 'data': data})
43
+
44
+ img_db = LineKeypointsDB(sample['data'], sample['image'])
45
+ target = img_db.get_tensor()
46
+
47
+ return sample['image'], torch.from_numpy(target).float()
48
+
49
+ def get_main_camera(self):
50
+ self.files = [file for file in self.files if int(self.match_info[file.split('/')[-1]]['ms_time']) == \
51
+ int(self.match_info[file.split('/')[-1]]['replay_time'])]
52
+
53
+
54
+ class WorldCup2014Dataset(Dataset):
55
+
56
+ def __init__(self, root_dir, split, transform):
57
+ self.root_dir = root_dir
58
+ self.split = split
59
+ self.transform = transform
60
+ assert self.split in ['train_val', 'test'], f'unknown dataset type {self.split}'
61
+
62
+ self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
63
+ self.homographies = glob.glob(os.path.join(self.root_dir + self.split, "*.homographyMatrix"))
64
+ self.num_samples = len(self.files)
65
+
66
+ self.files.sort()
67
+ self.homographies.sort()
68
+
69
+ def __len__(self):
70
+ return self.num_samples
71
+
72
+ def __getitem__(self, idx):
73
+ image = self.get_image_by_index(idx)
74
+ homography = self.get_homography_by_index(idx)
75
+ img_db = LineKeypointsWCDB(image, homography, (960,540))
76
+ target, mask = img_db.get_tensor_w_mask()
77
+
78
+ sample = self.transform({'image': image, 'target': target, 'mask': mask})
79
+
80
+ return sample['image'], sample['target'], sample['mask']
81
+
82
+ def convert_homography_WC14GT_to_SN(self, H):
83
+ T = np.eye(3)
84
+ #T[0, -1] = -115 / 2
85
+ #T[1, -1] = -74 / 2
86
+ yard2meter = 0.9144
87
+ S = np.eye(3)
88
+ S[0, 0] = yard2meter
89
+ S[1, 1] = yard2meter
90
+ H_SN = S @ (T @ H)
91
+
92
+ return H_SN
93
+
94
+ def get_image_by_index(self, index):
95
+ img_file = self.files[index]
96
+ image = Image.open(img_file)
97
+ return image
98
+
99
+ def get_homography_by_index(self, index):
100
+ homography_file = self.homographies[index]
101
+ with open(homography_file, 'r') as file:
102
+ lines = file.readlines()
103
+ matrix_elements = []
104
+ for line in lines:
105
+ matrix_elements.extend([float(element) for element in line.split()])
106
+ homography = np.array(matrix_elements).reshape((3, 3))
107
+ homography = self.convert_homography_WC14GT_to_SN(homography)
108
+ homography = torch.from_numpy(homography)
109
+ homography = homography / homography[2:3, 2:3]
110
+ return homography
111
+
112
+
113
+
114
+ class TSWorldCupDataset(Dataset):
115
+
116
+ def __init__(self, root_dir, split, transform):
117
+ self.root_dir = root_dir
118
+ self.split = split
119
+ self.transform = transform
120
+ assert self.split in ['train', 'test'], f'unknown dataset type {self.split}'
121
+
122
+ self.files_txt = self.get_txt()
123
+
124
+ self.files = self.get_jpg_files()
125
+ self.homographies = self.get_homographies()
126
+ self.num_samples = len(self.files)
127
+
128
+ self.files.sort()
129
+ self.homographies.sort()
130
+
131
+ def __len__(self):
132
+ return self.num_samples
133
+
134
+ def __getitem__(self, idx):
135
+ image = self.get_image_by_index(idx)
136
+ homography = self.get_homography_by_index(idx)
137
+ img_db = LineKeypointsWCDB(image, homography, (960,540))
138
+ target, mask = img_db.get_tensor_w_mask()
139
+
140
+ sample = self.transform({'image': image, 'target': target, 'mask': mask})
141
+
142
+ return sample['image'], sample['target'], sample['mask']
143
+
144
+
145
+ def get_txt(self):
146
+ with open(self.root_dir + self.split + '.txt', 'r') as file:
147
+ lines = file.readlines()
148
+ lines = [line.strip() for line in lines]
149
+ return lines
150
+
151
+ def get_jpg_files(self):
152
+ all_jpg_files = []
153
+ for dir in self.files_txt:
154
+ full_dir = self.root_dir + "Dataset/80_95/" + dir
155
+ jpg_files = []
156
+ for file in os.listdir(full_dir):
157
+ if file.lower().endswith('.jpg') or file.lower().endswith('.jpeg'):
158
+ jpg_files.append(os.path.join(full_dir, file))
159
+
160
+ all_jpg_files.extend(jpg_files)
161
+
162
+ return all_jpg_files
163
+
164
+ def get_homographies(self):
165
+ all_homographies = []
166
+ for dir in self.files_txt:
167
+ full_dir = self.root_dir + "Annotations/80_95/" + dir
168
+ homographies = []
169
+ for file in os.listdir(full_dir):
170
+ if file.lower().endswith('.npy'):
171
+ homographies.append(os.path.join(full_dir, file))
172
+
173
+ all_homographies.extend(homographies)
174
+
175
+ return all_homographies
176
+
177
+
178
+ def convert_homography_WC14GT_to_SN(self, H):
179
+ T = np.eye(3)
180
+ #T[0, -1] = -115 / 2
181
+ #T[1, -1] = -74 / 2
182
+ yard2meter = 0.9144
183
+ S = np.eye(3)
184
+ S[0, 0] = yard2meter
185
+ S[1, 1] = yard2meter
186
+ H_SN = S @ (T @ H)
187
+
188
+ return H_SN
189
+
190
+ def get_image_by_index(self, index):
191
+ img_file = self.files[index]
192
+ image = Image.open(img_file)
193
+ return image
194
+
195
+ def get_homography_by_index(self, index):
196
+ homography_file = self.homographies[index]
197
+ homography = np.load(homography_file)
198
+ homography = self.convert_homography_WC14GT_to_SN(homography)
199
+ homography = torch.from_numpy(homography)
200
+ homography = homography / homography[2:3, 2:3]
201
+ return homography
202
+
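These are the line-extremity counterparts of the keypoint dataloaders: `SoccerNetCalibrationDataset` here returns only `(image, target)` with no mask, while the two World Cup datasets keep the `(image, target, mask)` format. Note that the `match_info.json` load is commented out in this variant while `get_main_camera` still reads `self.match_info`, so the default `main_cam_only=True` would raise an `AttributeError`; the sketch below therefore passes `main_cam_only=False`. Paths, the split name and the `model.transforms_l` import are assumptions (a `no_transforms` pipeline analogous to the one in `model/transforms.py` is assumed there).

```python
# Sketch only: paths, split name and the transforms_l import are assumptions.
from torch.utils.data import DataLoader

from model.dataloader_l import SoccerNetCalibrationDataset
from model.transforms_l import no_transforms      # assumed line-variant pipeline

dataset = SoccerNetCalibrationDataset("datasets/calibration/", "valid",
                                      transform=no_transforms,
                                      main_cam_only=False)   # see note above
loader = DataLoader(dataset, batch_size=2, shuffle=False)

image, target = next(iter(loader))                # line heatmaps, no mask in this variant
print(image.shape, target.shape)
```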
model/losses.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class MSELoss(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.criterion = nn.MSELoss(reduction='none')
9
+
10
+ def forward(self, output, target, mask=None):
11
+ loss = self.criterion(output, target)
12
+ if mask is not None:
13
+ loss = (loss * mask).mean()
14
+ else:
15
+ loss = (loss).mean()
16
+ return loss
17
+
18
+ class KLDivLoss(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+ self.criterion = nn.KLDivLoss(reduction='batchmean')
22
+
23
+ def forward(self, output, target, mask=None):
24
+ if mask is not None:
25
+ output_masked = output * mask
26
+ target_masked = target * mask
27
+ loss = self.criterion(F.log_softmax(output_masked, dim=1), target_masked)
28
+ else:
29
+ loss = self.criterion(F.log_softmax(output, dim=1), target)
30
+ return loss
31
+
32
+ class HeatmapWeightingMSELoss(nn.Module):
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+ self.criterion = nn.MSELoss(reduction='none')
37
+
38
+ def forward(self, output, target, mask=None):
39
+ """Forward function."""
40
+ batch_size = output.size(0)
41
+ num_joints = output.size(1)
42
+
43
+ heatmaps_pred = output.reshape(
44
+ (batch_size, num_joints, -1)).split(1, 1)
45
+ heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
46
+
47
+ loss = 0.
48
+
49
+ for idx in range(num_joints):
50
+ heatmap_pred = heatmaps_pred[idx].squeeze(1)
51
+ heatmap_gt = heatmaps_gt[idx].squeeze(1)
52
+ """
53
+ Set different weight generation functions.
54
+ weight = heatmap_gt + 1
55
+ weight = heatmap_gt * 2 + 1
56
+ weight = heatmap_gt * heatmap_gt + 1
57
+ weight = torch.exp(heatmap_gt + 1)
58
+ """
59
+
60
+ if mask is not None:
61
+ #weight = heatmap_gt * mask[:, idx] + 1
62
+ weight = torch.exp(heatmap_gt * mask[:, idx] + 1)
63
+ loss += torch.mean(self.criterion(heatmap_pred * mask[:, idx],
64
+ heatmap_gt * mask[:, idx]) * weight)
65
+ else:
66
+ weight = heatmap_gt + 1
67
+ loss += torch.mean(self.criterion(heatmap_pred, heatmap_gt) * weight)
68
+ return loss / (num_joints+1)
69
+
70
+
71
+ class CombMSEAW(nn.Module):
72
+ def __init__(self, lambda1=1, lambda2=1, alpha=2.1, omega=14, epsilon=1, theta=0.5):
73
+ super().__init__()
74
+ # Adaptive wing loss
75
+ self.lambda1 = lambda1
76
+ self.lambda2 = lambda2
77
+ self.criterion1 = nn.MSELoss(reduction='none')
78
+ self.alpha = alpha
79
+ self.omega = omega
80
+ self.epsilon = epsilon
81
+ self.theta = theta
82
+
83
+
84
+ def forward(self, pred, target, mask=None):
85
+ loss = 0
86
+ if mask is not None:
87
+ pred_masked, target_masked = pred * mask, target * mask
88
+ loss += self.lambda1 * self.criterion1(pred_masked, target_masked)
89
+ loss += self.lambda2 * self.adaptive_wing(pred_masked, target_masked)
90
+ else:
91
+ loss += self.lambda1 * self.criterion1(pred, target)
92
+ loss += self.lambda2 * self.adaptive_wing(pred, target)
93
+ return torch.mean(loss)
94
+
95
+ def adaptive_wing(self, pred, target):
96
+ delta = (target - pred).abs()
97
+ alpha_t = self.alpha - target
98
+ A = self.omega * (
99
+ 1 / (1 + torch.pow(self.theta / self.epsilon,
100
+ alpha_t))) * alpha_t \
101
+ * (torch.pow(self.theta / self.epsilon,
102
+ self.alpha - target - 1)) * (1 / self.epsilon)
103
+ C = self.theta * A - self.omega * torch.log(
104
+ 1 + torch.pow(self.theta / self.epsilon, alpha_t))
105
+
106
+ losses = torch.where(delta < self.theta,
107
+ self.omega * torch.log(
108
+ 1 + torch.pow(delta / self.epsilon, alpha_t)),
109
+ A * delta - C)
110
+ return losses
111
+
112
+
113
+
114
+ class AdaptiveWingLoss(nn.Module):
115
+ def __init__(self, alpha=2.1, omega=14, epsilon=1, theta=0.5):
116
+ super().__init__()
117
+ # Adaptive wing loss
118
+ self.alpha = alpha
119
+ self.omega = omega
120
+ self.epsilon = epsilon
121
+ self.theta = theta
122
+
123
+ def forward(self, pred, target, mask=None):
124
+ if mask is not None:
125
+ pred_masked, target_masked = pred * mask, target * mask
126
+ loss = self.adaptive_wing(pred_masked, target_masked)
127
+ else:
128
+ loss = self.adaptive_wing(pred, target)
129
+ return loss
130
+
131
+ def adaptive_wing(self, pred, target):
132
+ delta = (target - pred).abs()
133
+ alpha_t = self.alpha - target
134
+ A = self.omega * (
135
+ 1 / (1 + torch.pow(self.theta / self.epsilon,
136
+ alpha_t))) * alpha_t \
137
+ * (torch.pow(self.theta / self.epsilon,
138
+ self.alpha - target - 1)) * (1 / self.epsilon)
139
+ C = self.theta * A - self.omega * torch.log(
140
+ 1 + torch.pow(self.theta / self.epsilon, alpha_t))
141
+
142
+ losses = torch.where(delta < self.theta,
143
+ self.omega * torch.log(
144
+ 1 + torch.pow(delta / self.epsilon, alpha_t)),
145
+ A * delta - C)
146
+ return torch.mean(losses)
147
+
148
+ class GaussianFocalLoss(nn.Module):
149
+ """GaussianFocalLoss is a variant of focal loss.
150
+ More details can be found in the `paper
151
+ <https://arxiv.org/abs/1808.01244>`_
152
+ Code is modified from `kp_utils.py
153
+ <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
154
+ Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
155
+ not 0/1 binary target.
156
+ Args:
157
+ alpha (float): Power of prediction.
158
+ gamma (float): Power of target for negative samples.
159
+ reduction (str): Options are "none", "mean" and "sum".
160
+ loss_weight (float): Loss weight of current loss.
161
+ """
162
+
163
+ def __init__(self,
164
+ alpha=2.0,
165
+ gamma=4.0,
166
+ reduction='mean',
167
+ loss_weight=1.0):
168
+ super(GaussianFocalLoss, self).__init__()
169
+ self.alpha = alpha
170
+ self.gamma = gamma
171
+ self.reduction = reduction
172
+ self.loss_weight = loss_weight
173
+
174
+ def forward(self,
175
+ pred,
176
+ target,
177
+ mask=None,
178
+ weight=None,
179
+ avg_factor=None,
180
+ reduction_override=None):
181
+ """Forward function.
182
+ Args:
183
+ pred (torch.Tensor): The prediction.
184
+ target (torch.Tensor): The learning target of the prediction
185
+ in gaussian distribution.
186
+ weight (torch.Tensor, optional): The weight of loss for each
187
+ prediction. Defaults to None.
188
+ avg_factor (int, optional): Average factor that is used to average
189
+ the loss. Defaults to None.
190
+ reduction_override (str, optional): The reduction method used to
191
+ override the original reduction method of the loss.
192
+ Defaults to None.
193
+ """
194
+ assert reduction_override in (None, 'none', 'mean', 'sum')
195
+ reduction = (reduction_override if reduction_override else self.reduction)
196
+ if mask is not None:
197
+ pred_masked, target_masked = pred * mask, target * mask
198
+ loss_reg = self.loss_weight * self.gaussian_focal_loss(pred_masked, target_masked, alpha=self.alpha,
199
+ gamma=self.gamma)
200
+ else:
201
+ loss_reg = self.loss_weight * self.gaussian_focal_loss(pred, target, alpha=self.alpha, gamma=self.gamma)
202
+ return loss_reg.mean()
203
+
204
+ def gaussian_focal_loss(self, pred, gaussian_target, alpha=2.0, gamma=4.0):
205
+ """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
206
+ distribution.
207
+ Args:
208
+ pred (torch.Tensor): The prediction.
209
+ gaussian_target (torch.Tensor): The learning target of the prediction
210
+ in gaussian distribution.
211
+ alpha (float, optional): A balanced form for Focal Loss.
212
+ Defaults to 2.0.
213
+ gamma (float, optional): The gamma for calculating the modulating
214
+ factor. Defaults to 4.0.
215
+ """
216
+ eps = 1e-12
217
+ pos_weights = gaussian_target.eq(1)
218
+ neg_weights = (1 - gaussian_target).pow(gamma)
219
+ pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
220
+ neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
221
+ return pos_loss + neg_loss
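All of the losses above share the `forward(pred, target, mask=None)` interface, where `mask` zeroes out channels of keypoints that are not annotated in a frame. A minimal sketch of swapping them against each other; the tensor shapes, including the 58-channel keypoint count, are illustrative.

```python
# Sketch only: shapes and the keypoint count are illustrative.
import torch

from model.losses import MSELoss, AdaptiveWingLoss, CombMSEAW, GaussianFocalLoss

pred = torch.rand(2, 58, 135, 240)      # B x K x H x W, e.g. sigmoid heatmaps
target = torch.rand(2, 58, 135, 240)
mask = torch.ones(2, 58, 1, 1)          # 0/1 per keypoint channel, broadcast over H x W

for criterion in (MSELoss(), AdaptiveWingLoss(), CombMSEAW(), GaussianFocalLoss()):
    loss = criterion(pred, target, mask=mask)
    print(type(criterion).__name__, float(loss))
```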
model/metrics.py ADDED
@@ -0,0 +1,312 @@
1
+ import cv2
2
+ import torch
3
+ import scipy
4
+ import random
5
+ import numpy as np
6
+
7
+ from shapely.geometry import Point, Polygon, MultiPoint
8
+
9
+ def calculate_metrics(gt, pred, mask, conf_th=0.1, dist_th=5):
10
+ geometry_mask = (mask[:, :-1] > 0).cpu()
11
+
12
+ pred_mask = torch.all((pred[:, :, :, -1] > conf_th), dim=-1)
13
+ gt_mask = torch.all((gt[:, :, :, -1] > conf_th), dim=-1)
14
+
15
+ pred_pos = pred[geometry_mask][:, 0, :]
16
+ pred_mask = pred_mask[geometry_mask]
17
+ gt_pos = gt[geometry_mask][:, 0, :]
18
+ gt_mask = gt_mask[geometry_mask]
19
+
20
+ distances = torch.norm(pred_pos - gt_pos, dim=1)
21
+
22
+ # Count true positives, false positives, and false negatives based on distance threshold
23
+ true_positives = ((distances < dist_th) & pred_mask & gt_mask).sum().item()
24
+ true_negatives = (~pred_mask & ~gt_mask).sum().item()
25
+ false_positives = ((pred_mask & ~gt_mask) | ((distances >= dist_th) & pred_mask & gt_mask)).sum().item()
26
+ false_negatives = (~pred_mask & gt_mask).sum().item()
27
+
28
+ # Calculate precision, recall, and F1 score
29
+ accuracy = (true_positives + true_negatives) / geometry_mask.sum().item()
30
+ precision = true_positives / (true_positives + false_positives + 1e-10)
31
+ recall = true_positives / (true_positives + false_negatives + 1e-10)
32
+ f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
33
+
34
+ return accuracy, precision, recall, f1
35
+
36
+
37
+ def calculate_metrics_l(gt, pred, conf_th=0.1, dist_th=5):
38
+
39
+ pred_pos = pred[:, :, :, :-1]
40
+ gt_pos = gt[:, :, :, :-1]
41
+
42
+ pred_mask = torch.all((pred[:, :, :, -1] > conf_th), dim=-1)
43
+ gt_mask = torch.all((gt[:, :, :, -1] > conf_th), dim=-1)
44
+
45
+ gt_flip = torch.flip(gt_pos, dims=[2])
46
+
47
+ distances1 = torch.norm(pred_pos - gt_pos, dim=-1)
48
+ distances2 = torch.norm(pred_pos - gt_flip, dim=-1)
49
+
50
+ distances1_bool = torch.all((distances1 < dist_th), dim=-1)
51
+ distances2_bool = torch.all((distances2 < dist_th), dim=-1)
52
+
53
+ # Count true positives, false positives, and false negatives based on distance threshold
54
+ true_positives = ((distances1_bool | distances2_bool) & pred_mask & gt_mask).sum().item()
55
+ true_negatives = (~pred_mask & ~gt_mask).sum().item()
56
+ false_positives = (
57
+ (pred_mask & ~gt_mask) | ((~distances1_bool & ~distances2_bool) & pred_mask & gt_mask)).sum().item()
58
+ false_negatives = (~pred_mask & gt_mask).sum().item()
59
+
60
+ # Calculate precision, recall, and F1 score
61
+ accuracy = (true_positives + true_negatives) / (gt.size()[1] * gt.size()[0])
62
+ precision = true_positives / (true_positives + false_positives + 1e-10)
63
+ recall = true_positives / (true_positives + false_negatives + 1e-10)
64
+ f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
65
+
66
+ return accuracy, precision, recall, f1
67
+
68
+
69
+ def calculate_metrics_l_with_mask(gt, pred, mask, conf_th=0.1, dist_th=5):
70
+
71
+ # Only works with batch size 1; for larger batches, adapt the masking or loop over the batch.
72
+
73
+ geometry_mask = (mask[:, :-1] > 0).cpu()
74
+
75
+ pred = pred[geometry_mask]
76
+ gt = gt[geometry_mask]
77
+
78
+ pred_pos = pred[:, :, :-1]
79
+ gt_pos = gt[:, :, :-1]
80
+
81
+ pred_mask = torch.all((pred[:, :, -1] > conf_th), dim=-1)
82
+ gt_mask = torch.all((gt[:, :, -1] > conf_th), dim=-1)
83
+
84
+ gt_flip = torch.flip(gt_pos, dims=[1])
85
+
86
+ distances1 = torch.norm(pred_pos - gt_pos, dim=-1)
87
+ distances2 = torch.norm(pred_pos - gt_flip, dim=-1)
88
+
89
+ distances1_bool = torch.all((distances1 < dist_th), dim=-1)
90
+ distances2_bool = torch.all((distances2 < dist_th), dim=-1)
91
+
92
+ # Count true positives, false positives, and false negatives based on distance threshold
93
+ true_positives = ((distances1_bool | distances2_bool) & pred_mask & gt_mask).sum().item()
94
+ true_negatives = (~pred_mask & ~gt_mask).sum().item()
95
+ false_positives = (
96
+ (pred_mask & ~gt_mask) | ((~distances1_bool & ~distances2_bool) & pred_mask & gt_mask)).sum().item()
97
+ false_negatives = (~pred_mask & gt_mask).sum().item()
98
+
99
+ # Calculate precision, recall, and F1 score
100
+ accuracy = (true_positives + true_negatives) / geometry_mask.sum().item()
101
+ precision = true_positives / (true_positives + false_positives + 1e-10)
102
+ recall = true_positives / (true_positives + false_negatives + 1e-10)
103
+ f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
104
+
105
+ return accuracy, precision, recall, f1
106
+
107
+
108
+ def calc_iou_whole_with_poly(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):
109
+
110
+ corners = np.array([[0, 0],
111
+ [frame_w - 1, 0],
112
+ [frame_w - 1, frame_h - 1],
113
+ [0, frame_h - 1]], dtype=np.float64)
114
+
115
+ mapping_mat = np.linalg.inv(gt_h)
116
+ mapping_mat /= mapping_mat[2, 2]
117
+
118
+ gt_corners = cv2.perspectiveTransform(
119
+ corners[:, None, :], gt_h) # inv_gt_mat * (gt_mat * [x, y, 1])
120
+ gt_corners = cv2.perspectiveTransform(
121
+ gt_corners, np.linalg.inv(gt_h))
122
+ gt_corners = gt_corners[:, 0, :]
123
+
124
+ pred_corners = cv2.perspectiveTransform(
125
+ corners[:, None, :], gt_h) # inv_pred_mat * (gt_mat * [x, y, 1])
126
+ pred_corners = cv2.perspectiveTransform(
127
+ pred_corners, np.linalg.inv(pred_h))
128
+ pred_corners = pred_corners[:, 0, :]
129
+
130
+ gt_poly = Polygon(gt_corners.tolist())
131
+ pred_poly = Polygon(pred_corners.tolist())
132
+
133
+ # f, axarr = plt.subplots(1, 2, figsize=(16, 12))
134
+ # axarr[0].plot(*gt_poly.exterior.coords.xy)
135
+ # axarr[1].plot(*pred_poly.exterior.coords.xy)
136
+ # plt.show()
137
+
138
+ if pred_poly.is_valid is False:
139
+ return 0., None, None
140
+
141
+ if not gt_poly.intersects(pred_poly):
142
+ print('not intersects')
143
+ iou = 0.
144
+ else:
145
+ intersection = gt_poly.intersection(pred_poly).area
146
+ union = gt_poly.area + pred_poly.area - intersection
147
+ if union <= 0.:
148
+ print('whole union', union)
149
+ iou = 0.
150
+ else:
151
+ iou = intersection / union
152
+
153
+ return iou, None, None
154
+
155
+ def calc_iou_part(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):
156
+
157
+ # field template binary mask
158
+ field_mask = np.ones((frame_h, frame_w, 3), dtype=np.uint8) * 255
159
+ gt_mask = cv2.warpPerspective(field_mask, gt_h, (template_w, template_h),
160
+ cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
161
+
162
+ pred_mask = cv2.warpPerspective(field_mask, pred_h, (template_w, template_h),
163
+ cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
164
+
165
+ gt_mask[gt_mask > 0] = 255
166
+ pred_mask[pred_mask > 0] = 255
167
+
168
+ intersection = ((gt_mask > 0) * (pred_mask > 0)).sum()
169
+ union = (gt_mask > 0).sum() + (pred_mask > 0).sum() - intersection
170
+
171
+ if union <= 0:
172
+ print('part union', union)
173
+ # iou = float('nan')
174
+ iou = 0.
175
+ else:
176
+ iou = float(intersection) / float(union)
177
+
178
+ # === blending ===
179
+ gt_white_area = (gt_mask[:, :, 0] == 255) & (
180
+ gt_mask[:, :, 1] == 255) & (gt_mask[:, :, 2] == 255)
181
+ gt_fill = gt_mask.copy()
182
+ gt_fill[gt_white_area, 0] = 255
183
+ gt_fill[gt_white_area, 1] = 0
184
+ gt_fill[gt_white_area, 2] = 0
185
+ pred_white_area = (pred_mask[:, :, 0] == 255) & (
186
+ pred_mask[:, :, 1] == 255) & (pred_mask[:, :, 2] == 255)
187
+ pred_fill = pred_mask.copy()
188
+ pred_fill[pred_white_area, 0] = 0
189
+ pred_fill[pred_white_area, 1] = 255
190
+ pred_fill[pred_white_area, 2] = 0
191
+ gt_maskf = gt_fill.astype(float) / 255
192
+ pred_maskf = pred_fill.astype(float) / 255
193
+ fill_resultf = cv2.addWeighted(gt_maskf, 0.5,
194
+ pred_maskf, 0.5, 0.0)
195
+ fill_result = np.uint8(fill_resultf * 255)
196
+
197
+ return iou
198
+
199
+ def calc_proj_error(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):
200
+
201
+ field_mask = np.ones((template_h, template_w, 3), dtype=np.uint8) * 255
202
+ gt_mask = cv2.warpPerspective(field_mask, np.linalg.inv(
203
+ gt_h), (frame_w, frame_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
204
+ gt_gray = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
205
+ contours, hierarchy = cv2.findContours(
206
+ gt_gray, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
207
+ contour = np.squeeze(contours[0])
208
+ poly = Polygon(contour)
209
+ sample_pts = []
210
+ num_pts = 2500
211
+ while len(sample_pts) <= num_pts:
212
+ x = random.sample(range(0, frame_w), 1)
213
+ y = random.sample(range(0, frame_h), 1)
214
+ p = Point(x[0], y[0])
215
+ if p.within(poly):
216
+ sample_pts.append([x[0], y[0]])
217
+ sample_pts = np.array(sample_pts, dtype=np.float32)
218
+
219
+ field_dim_x, field_dim_y = 100, 60
220
+ x_scale = field_dim_x / template_w
221
+ y_scale = field_dim_y / template_h
222
+ scaling_mat = np.eye(3)
223
+ scaling_mat[0, 0] = x_scale
224
+ scaling_mat[1, 1] = y_scale
225
+ gt_temp_grid = cv2.perspectiveTransform(
226
+ sample_pts.reshape(-1, 1, 2), scaling_mat @ gt_h)
227
+ gt_temp_grid = gt_temp_grid.reshape(-1, 2)
228
+ pred_temp_grid = cv2.perspectiveTransform(
229
+ sample_pts.reshape(-1, 1, 2), scaling_mat @ pred_h)
230
+ pred_temp_grid = pred_temp_grid.reshape(-1, 2)
231
+
232
+ # TODO compute distance in top view
233
+ gt_grid_list = []
234
+ pred_grid_list = []
235
+ for gt_pts, pred_pts in zip(gt_temp_grid, pred_temp_grid):
236
+ if 0 <= gt_pts[0] < field_dim_x and 0 <= gt_pts[1] < field_dim_y and \
237
+ 0 <= pred_pts[0] < field_dim_x and 0 <= pred_pts[1] < field_dim_y:
238
+ gt_grid_list.append(gt_pts)
239
+ pred_grid_list.append(pred_pts)
240
+ gt_grid_list = np.array(gt_grid_list)
241
+ pred_grid_list = np.array(pred_grid_list)
242
+
243
+ if gt_grid_list.shape != pred_grid_list.shape:
244
+ print('proj error:', gt_grid_list.shape, pred_grid_list.shape)
245
+ assert gt_grid_list.shape == pred_grid_list.shape, 'shape mismatch'
246
+
247
+ if gt_grid_list.size != 0 and pred_grid_list.size != 0:
248
+ distance_list = calc_euclidean_distance(
249
+ gt_grid_list, pred_grid_list, axis=1)
250
+ return distance_list.mean() # average all keypoints
251
+ else:
252
+ print(gt_grid_list)
253
+ print(pred_grid_list)
254
+ return float('nan')
255
+
256
+ def calc_euclidean_distance(a, b, _norm=np.linalg.norm, axis=None):
257
+ return _norm(a - b, axis=axis)
258
+
259
+ def gen_template_grid():
260
+ # === set uniform grid ===
261
+ # field_dim_x, field_dim_y = 105.000552, 68.003928 # in meter
262
+ field_dim_x, field_dim_y = 114.83, 74.37 # in yard
263
+ # field_dim_x, field_dim_y = 115, 74 # in yard
264
+ nx, ny = (13, 7)
265
+ x = np.linspace(0, field_dim_x, nx)
266
+ y = np.linspace(0, field_dim_y, ny)
267
+ xv, yv = np.meshgrid(x, y, indexing='ij')
268
+ uniform_grid = np.stack((xv, yv), axis=2).reshape(-1, 2)
269
+ uniform_grid = np.concatenate((uniform_grid, np.ones(
270
+ (uniform_grid.shape[0], 1))), axis=1) # top2bottom, left2right
271
+ # TODO: class label in template, each keypoints is (x, y, c), c is label that starts from 1
272
+ for idx, pts in enumerate(uniform_grid):
273
+ pts[2] = idx + 1 # keypoints label
274
+ return uniform_grid
275
+
276
+ def calc_reproj_error(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):
277
+
278
+ uniform_grid = gen_template_grid() # grid shape (91, 3), (x, y, label)
279
+ template_grid = uniform_grid[:, :2].copy()
280
+ template_grid = template_grid.reshape(-1, 1, 2)
281
+
282
+ gt_warp_grid = cv2.perspectiveTransform(template_grid, np.linalg.inv(gt_h))
283
+ gt_warp_grid = gt_warp_grid.reshape(-1, 2)
284
+ pred_warp_grid = cv2.perspectiveTransform(
285
+ template_grid, np.linalg.inv(pred_h))
286
+ pred_warp_grid = pred_warp_grid.reshape(-1, 2)
287
+
288
+ # TODO compute distance in camera view
289
+ gt_grid_list = []
290
+ pred_grid_list = []
291
+ for gt_pts, pred_pts in zip(gt_warp_grid, pred_warp_grid):
292
+ if 0 <= gt_pts[0] < frame_w and 0 <= gt_pts[1] < frame_h and \
293
+ 0 <= pred_pts[0] < frame_w and 0 <= pred_pts[1] < frame_h:
294
+ gt_grid_list.append(gt_pts)
295
+ pred_grid_list.append(pred_pts)
296
+ gt_grid_list = np.array(gt_grid_list)
297
+ pred_grid_list = np.array(pred_grid_list)
298
+
299
+ if gt_grid_list.shape != pred_grid_list.shape:
300
+ print('reproj error:', gt_grid_list.shape, pred_grid_list.shape)
301
+ assert gt_grid_list.shape == pred_grid_list.shape, 'shape mismatch'
302
+
303
+ if gt_grid_list.size != 0 and pred_grid_list.size != 0:
304
+ distance_list = calc_euclidean_distance(
305
+ gt_grid_list, pred_grid_list, axis=1)
306
+ distance_list /= frame_h # normalize by image height
307
+ return distance_list.mean() # average all keypoints
308
+ else:
309
+ print(gt_grid_list)
310
+ print(pred_grid_list)
311
+ return float('nan')
312
+
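Besides the heatmap precision/recall helpers, this module bundles the homography-level metrics used in WorldCup-style evaluations (`calc_iou_part`, `calc_iou_whole_with_poly`, `calc_proj_error`, `calc_reproj_error`), all of which compare a predicted image-to-template homography against the ground truth. A small self-contained sketch; the matrices below are synthetic placeholders, real ones come from the calibration pipeline.

```python
# Sketch only: the homographies below are synthetic placeholders.
import numpy as np

from model.metrics import calc_iou_whole_with_poly, calc_reproj_error

# Ground-truth image -> template homography (roughly: scale down and translate).
gt_h = np.array([[0.1, 0.0, 10.0],
                 [0.0, 0.1,  5.0],
                 [0.0, 0.0,  1.0]])
# A slightly perturbed prediction.
pred_h = gt_h @ np.array([[1.01, 0.00, 0.5],
                          [0.00, 1.01, 0.3],
                          [0.00, 0.00, 1.0]])

iou_whole, _, _ = calc_iou_whole_with_poly(pred_h, gt_h)   # whole-frame polygon IoU
reproj = calc_reproj_error(pred_h, gt_h)                   # mean keypoint error, normalized by frame height
print(iou_whole, reproj)
```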
model/transforms.py ADDED
@@ -0,0 +1,364 @@
1
+ import torch
2
+ import math
3
+ import numbers
4
+ import warnings
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms.functional as f
8
+ import torchvision.transforms as T
9
+ import torchvision.transforms.v2 as v2
10
+
11
+ from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
12
+ from torchvision import transforms as _transforms
13
+ from typing import List, Optional, Tuple, Union
14
+ from scipy import ndimage
15
+ from torch import Tensor
16
+
17
+ from sn_calibration.src.evaluate_extremities import mirror_labels
18
+
19
+
20
+
21
+ class ToTensor(torch.nn.Module):
22
+ def __call__(self, sample):
23
+ image = sample['image']
24
+
25
+
26
+ return {'image': f.to_tensor(image).float(),
27
+ 'data': sample['data']}
28
+
29
+ def __repr__(self) -> str:
30
+ return f"{self.__class__.__name__}()"
31
+
32
+
33
+ class Normalize(torch.nn.Module):
34
+ def __init__(self, mean, std):
35
+ super().__init__()
36
+ self.mean = mean
37
+ self.std = std
38
+
39
+ def forward(self, sample):
40
+ image = sample['image']
41
+ image = f.normalize(image, self.mean, self.std)
42
+
43
+ return {'image': image,
44
+ 'data': sample['data']}
45
+
46
+
47
+ def __repr__(self) -> str:
48
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
49
+
50
+
51
+ FLIP_POSTS = {
52
+ 'Goal left post right': 'Goal left post left ',
53
+ 'Goal left post left ': 'Goal left post right',
54
+ 'Goal right post right': 'Goal right post left',
55
+ 'Goal right post left': 'Goal right post right'
56
+ }
57
+
58
+ h_lines = ['Goal left crossbar', 'Side line left', 'Small rect. left main', 'Big rect. left main', 'Middle line',
59
+ 'Big rect. right main', 'Small rect. right main', 'Side line right', 'Goal right crossbar']
60
+ v_lines = ['Side line top', 'Big rect. left top', 'Small rect. left top', 'Small rect. left bottom',
61
+ 'Big rect. left bottom', 'Big rect. right top', 'Small rect. right top', 'Small rect. right bottom',
62
+ 'Big rect. right bottom', 'Side line bottom']
63
+
64
+ def swap_top_bottom_names(line_name: str) -> str:
65
+ x: str = 'top'
66
+ y: str = 'bottom'
67
+ if x in line_name or y in line_name:
68
+ return y.join(part.replace(y, x) for part in line_name.split(x))
69
+ return line_name
70
+
71
+
72
+ def swap_posts_names(line_name: str) -> str:
73
+ if line_name in FLIP_POSTS:
74
+ return FLIP_POSTS[line_name]
75
+ return line_name
76
+
77
+
78
+ def flip_annot_names(annot, swap_top_bottom: bool = True,
79
+ swap_posts: bool = True):
80
+ annot = mirror_labels(annot)
81
+ if swap_top_bottom:
82
+ annot = {swap_top_bottom_names(k): v for k, v in annot.items()}
83
+ if swap_posts:
84
+ annot = {swap_posts_names(k): v for k, v in annot.items()}
85
+ return annot
86
+
87
+
88
+ class RandomHorizontalFlip(torch.nn.Module):
89
+ def __init__(self, p=0.5):
90
+ super().__init__()
91
+ self.p = p
92
+
93
+ def forward(self, sample):
94
+ if torch.rand(1) < self.p:
95
+ image, data = sample['image'], sample['data']
96
+ image = f.hflip(image)
97
+ data = flip_annot_names(data)
98
+ for line in data:
99
+ for point in data[line]:
100
+ point['x'] = 1.0 - point['x']
101
+
102
+ return {'image': image,
103
+ 'data': data}
104
+ else:
105
+ return {'image': sample['image'],
106
+ 'data': sample['data']}
107
+
108
+ def __repr__(self) -> str:
109
+ return f"{self.__class__.__name__}(p={self.p})"
110
+
111
+
112
+ class LRAmbiguityFix(torch.nn.Module):
113
+ def __init__(self, v_th, h_th):
114
+ super().__init__()
115
+ self.v_th = v_th
116
+ self.h_th = h_th
117
+
118
+ def forward(self, sample):
119
+ data = sample['data']
120
+
121
+ if len(data) == 0:
122
+ return {'image': sample['image'],
123
+ 'data': sample['data']}
124
+
125
+ n_left, n_right = self.compute_n_sides(data)
126
+
127
+ angles_v, angles_h = [], []
128
+ for line in data.keys():
129
+ line_points = []
130
+ for point in data[line]:
131
+ line_points.append((point['x'], point['y']))
132
+
133
+ sorted_points = sorted(line_points, key=lambda point: (point[0], point[1]))
134
+ pi, pf = sorted_points[0], sorted_points[-1]
135
+ if line in h_lines:
136
+ angle_h = self.calculate_angle_h(pi[0], pi[1], pf[0], pf[1])
137
+ if angle_h:
138
+ angles_h.append(abs(angle_h))
139
+ if line in v_lines:
140
+ angle_v = self.calculate_angle_v(pi[0], pi[1], pf[0], pf[1])
141
+ if angle_v:
142
+ angles_v.append(abs(angle_v))
143
+
144
+
145
+ if len(angles_h) > 0 and len(angles_v) > 0:
146
+ if np.mean(angles_h) < self.h_th and np.mean(angles_v) < self.v_th:
147
+ if n_right > n_left:
148
+ data = flip_annot_names(data, swap_top_bottom=False, swap_posts=False)
149
+
150
+ return {'image': sample['image'],
151
+ 'data': data}
152
+
153
+ def calculate_angle_h(self, x1, y1, x2, y2):
154
+ if not x2 - x1 == 0:
155
+ slope = (y2 - y1) / (x2 - x1)
156
+ angle = math.atan(slope)
157
+ angle_degrees = math.degrees(angle)
158
+ return angle_degrees
159
+ else:
160
+ return None
161
+ def calculate_angle_v(self, x1, y1, x2, y2):
162
+ if not x2 - x1 == 0:
163
+ slope = (y2 - y1) / (x2 - x1)
164
+ angle = math.atan(1 / slope) if slope != 0 else math.pi / 2 # Avoid division by zero
165
+ angle_degrees = math.degrees(angle)
166
+ return angle_degrees
167
+ else:
168
+ return None
169
+
170
+ def compute_n_sides(self, data):
171
+ n_left, n_right = 0, 0
172
+ for line in data:
173
+ line_words = line.split()[:3]
174
+ if 'left' in line_words:
175
+ n_left += 1
176
+ elif 'right' in line_words:
177
+ n_right += 1
178
+ return n_left, n_right
179
+
180
+ def __repr__(self) -> str:
181
+ return f"{self.__class__.__name__}(v_th={self.v_th}, h_th={self.h_th})"
182
+
183
+
184
+ class AddGaussianNoise(torch.nn.Module):
185
+ def __init__(self, mean=0., std=2.):
186
+ self.std = std
187
+ self.mean = mean
188
+
189
+ def __call__(self, sample):
190
+ image = sample['image']
191
+ image += torch.randn(image.size()) * self.std + self.mean
192
+ image = torch.clip(image, 0, 1)
193
+
194
+ return {'image': image,
195
+ 'data': sample['data']}
196
+
197
+ def __repr__(self):
198
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
199
+
200
+
201
+ class ColorJitter(torch.nn.Module):
202
+
203
+ def __init__(
204
+ self,
205
+ brightness: Union[float, Tuple[float, float]] = 0,
206
+ contrast: Union[float, Tuple[float, float]] = 0,
207
+ saturation: Union[float, Tuple[float, float]] = 0,
208
+ hue: Union[float, Tuple[float, float]] = 0,
209
+ ) -> None:
210
+ super().__init__()
211
+ self.brightness = self._check_input(brightness, "brightness")
212
+ self.contrast = self._check_input(contrast, "contrast")
213
+ self.saturation = self._check_input(saturation, "saturation")
214
+ self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
215
+
216
+ @torch.jit.unused
217
+ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
218
+ if isinstance(value, numbers.Number):
219
+ if value < 0:
220
+ raise ValueError(f"If {name} is a single number, it must be non negative.")
221
+ value = [center - float(value), center + float(value)]
222
+ if clip_first_on_zero:
223
+ value[0] = max(value[0], 0.0)
224
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
225
+ value = [float(value[0]), float(value[1])]
226
+ else:
227
+ raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
228
+
229
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
230
+ raise ValueError(f"{name} values should be between {bound}, but got {value}.")
231
+
232
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
233
+ # or (0., 0.) for hue, do nothing
234
+ if value[0] == value[1] == center:
235
+ return None
236
+ else:
237
+ return tuple(value)
238
+
239
+ @staticmethod
240
+ def get_params(
241
+ brightness: Optional[List[float]],
242
+ contrast: Optional[List[float]],
243
+ saturation: Optional[List[float]],
244
+ hue: Optional[List[float]],
245
+ ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
246
+ """Get the parameters for the randomized transform to be applied on image.
247
+
248
+ Args:
249
+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
250
+ uniformly. Pass None to turn off the transformation.
251
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
252
+ uniformly. Pass None to turn off the transformation.
253
+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
254
+ uniformly. Pass None to turn off the transformation.
255
+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
256
+ Pass None to turn off the transformation.
257
+
258
+ Returns:
259
+ tuple: The parameters used to apply the randomized transform
260
+ along with their random order.
261
+ """
262
+ fn_idx = torch.randperm(4)
263
+
264
+ b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
265
+ c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
266
+ s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
267
+ h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
268
+
269
+ return fn_idx, b, c, s, h
270
+
271
+
272
+ def forward(self, sample):
273
+ """
274
+ Args:
275
+ img (PIL Image or Tensor): Input image.
276
+
277
+ Returns:
278
+ PIL Image or Tensor: Color jittered image.
279
+ """
280
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
281
+ self.brightness, self.contrast, self.saturation, self.hue
282
+ )
283
+
284
+ image = sample['image']
285
+
286
+ for fn_id in fn_idx:
287
+ if fn_id == 0 and brightness_factor is not None:
288
+ image = f.adjust_brightness(image, brightness_factor)
289
+ elif fn_id == 1 and contrast_factor is not None:
290
+ image = f.adjust_contrast(image, contrast_factor)
291
+ elif fn_id == 2 and saturation_factor is not None:
292
+ image = f.adjust_saturation(image, saturation_factor)
293
+ elif fn_id == 3 and hue_factor is not None:
294
+ image = f.adjust_hue(image, hue_factor)
295
+
296
+ return {'image': image,
297
+ 'data': sample['data']}
298
+
299
+
300
+ def __repr__(self) -> str:
301
+ s = (
302
+ f"{self.__class__.__name__}("
303
+ f"brightness={self.brightness}"
304
+ f", contrast={self.contrast}"
305
+ f", saturation={self.saturation}"
306
+ f", hue={self.hue})"
307
+ )
308
+ return s
309
+
310
+
311
+ class Resize(torch.nn.Module):
312
+ def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
313
+ super().__init__()
314
+ self.size = size
315
+
316
+ # Backward compatibility with integer value
317
+ if isinstance(interpolation, int):
318
+ warnings.warn(
319
+ "Argument interpolation should be of type InterpolationMode instead of int. "
320
+ "Please, use InterpolationMode enum."
321
+ )
322
+ interpolation = _interpolation_modes_from_int(interpolation)
323
+
324
+ self.interpolation = interpolation
325
+
326
+ def forward(self, sample):
327
+ image = sample["image"]
328
+ image = f.resize(image, self.size, self.interpolation)
329
+
330
+ return {'image': image,
331
+ 'data': sample['data']}
332
+
333
+
334
+ def __repr__(self):
335
+ interpolate_str = self.interpolation.value
336
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
337
+
338
+
339
+
340
+ transforms = v2.Compose([
341
+ ToTensor(),
342
+ RandomHorizontalFlip(p=.5),
343
+ ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
344
+ AddGaussianNoise(0, .1)
345
+ ])
346
+
347
+ transforms_w_LR = v2.Compose([
348
+ ToTensor(),
349
+ RandomHorizontalFlip(p=.5),
350
+ LRAmbiguityFix(v_th=70, h_th=20),
351
+ ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
352
+ AddGaussianNoise(0, .1)
353
+ ])
354
+
355
+ no_transforms = v2.Compose([
356
+ ToTensor()
357
+ ])
358
+
359
+ no_transforms_w_LR = v2.Compose([
360
+ ToTensor(),
361
+ LRAmbiguityFix(v_th=70, h_th=20)
362
+ ])
363
+
364
+
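The composed pipelines at the end of this file operate on the same sample dictionaries the SoccerNet dataloader produces: a PIL image under `'image'` and the per-line annotation dictionary under `'data'`, with point coordinates normalized to [0, 1]. A short sketch of applying them outside the dataloader; the frame and annotation are toy placeholders using a standard SoccerNet line name.

```python
# Sketch only: the frame and annotation are toy placeholders.
from PIL import Image

from model.transforms import transforms, no_transforms

image = Image.new("RGB", (960, 540))                       # blank stand-in frame
data = {"Middle line": [{"x": 0.48, "y": 0.10},
                        {"x": 0.52, "y": 0.95}]}           # normalized coordinates

augmented = transforms({"image": image, "data": data})     # flip / color jitter / noise
plain = no_transforms({"image": image, "data": data})      # ToTensor only

print(augmented["image"].shape, list(augmented["data"].keys()))
```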
model/transformsWC.py ADDED
@@ -0,0 +1,212 @@
1
+ import torch
2
+ import math
3
+ import numbers
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms.functional as f
6
+ import torchvision.transforms as T
7
+ import torchvision.transforms.v2 as v2
8
+
9
+ from torchvision import transforms as _transforms
10
+ from typing import List, Optional, Tuple, Union
11
+ from scipy import ndimage
12
+ from torch import Tensor
13
+
14
+ from sn_calibration.src.evaluate_extremities import mirror_labels
15
+
16
+ class ToTensor(torch.nn.Module):
17
+ def __call__(self, sample):
18
+ image, target, mask = sample['image'], sample['target'], sample['mask']
19
+
20
+ return {'image': f.to_tensor(image).float(),
21
+ 'target': torch.from_numpy(target).float(),
22
+ 'mask': torch.from_numpy(mask).float()}
23
+
24
+ def __repr__(self) -> str:
25
+ return f"{self.__class__.__name__}()"
26
+
27
+ class RandomHorizontalFlip(torch.nn.Module):
28
+ def __init__(self, p=0.5):
29
+ super().__init__()
30
+ self.p = p
31
+ self.swap_dict = {1:3, 2:2, 3:1, 4:7, 5:6, 6:5, 7:4, 8:11, 9:10, 10:9, 11:8, 12:15, 13:14, 14:13, 15:12,
32
+ 16:19, 17:18, 18:17, 19:16, 20:23, 21:22, 22:21, 23:20, 24:27, 25:26, 26:25, 27:24, 28:30,
33
+ 29:29, 30:28, 31:33, 32:32, 33:31, 34:36, 35:35, 36:34, 37:40, 38:39, 39:38, 40:37, 41:44,
34
+ 42:43, 43:42, 44:41, 45:57, 46:56, 47:55, 48:49, 49:48, 50:52, 51:51, 52:50, 53:54, 54:53,
35
+ 55:47, 56:46, 57:45, 58:58}
36
+
37
+
38
+ def forward(self, sample):
39
+ if torch.rand(1) < self.p:
40
+ image, target, mask = sample['image'], sample['target'], sample['mask']
41
+ image = f.hflip(image)
42
+ target = f.hflip(target)
43
+
44
+ target_swap, mask_swap = self.swap_layers(target, mask)
45
+
46
+ return {'image': image,
47
+ 'target': target_swap,
48
+ 'mask': mask_swap}
49
+ else:
50
+ return {'image': sample['image'],
51
+ 'target': sample['target'],
52
+ 'mask': sample['mask']}
53
+
54
+
55
+ def swap_layers(self, target, mask):
56
+ target_swap = torch.zeros_like(target)
57
+ mask_swap = torch.zeros_like(mask)
58
+ for kp in self.swap_dict.keys():
59
+ kp_swap = self.swap_dict[kp]
60
+ target_swap[kp_swap-1, :, :] = target[kp-1, :, :].clone()
61
+ mask_swap[kp_swap-1] = mask[kp-1].clone()
62
+
63
+ return target_swap, mask_swap
64
+
65
+
66
+ def __repr__(self) -> str:
67
+ return f"{self.__class__.__name__}(p={self.p})"
68
+
69
+
70
+ class AddGaussianNoise(torch.nn.Module):
71
+ def __init__(self, mean=0., std=2.):
72
+ self.std = std
73
+ self.mean = mean
74
+
75
+ def __call__(self, sample):
76
+ image = sample['image']
77
+ image += torch.randn(image.size()) * self.std + self.mean
78
+ image = torch.clip(image, 0, 1)
79
+
80
+ return {'image': image,
81
+ 'target': sample['target'],
82
+ 'mask': sample['mask']}
83
+
84
+ def __repr__(self):
85
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
86
+
87
+
88
+ class ColorJitter(torch.nn.Module):
89
+
90
+ def __init__(
91
+ self,
92
+ brightness: Union[float, Tuple[float, float]] = 0,
93
+ contrast: Union[float, Tuple[float, float]] = 0,
94
+ saturation: Union[float, Tuple[float, float]] = 0,
95
+ hue: Union[float, Tuple[float, float]] = 0,
96
+ ) -> None:
97
+ super().__init__()
98
+ self.brightness = self._check_input(brightness, "brightness")
99
+ self.contrast = self._check_input(contrast, "contrast")
100
+ self.saturation = self._check_input(saturation, "saturation")
101
+ self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
102
+
103
+ @torch.jit.unused
104
+ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
105
+ if isinstance(value, numbers.Number):
106
+ if value < 0:
107
+ raise ValueError(f"If {name} is a single number, it must be non negative.")
108
+ value = [center - float(value), center + float(value)]
109
+ if clip_first_on_zero:
110
+ value[0] = max(value[0], 0.0)
111
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
112
+ value = [float(value[0]), float(value[1])]
113
+ else:
114
+ raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
115
+
116
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
117
+ raise ValueError(f"{name} values should be between {bound}, but got {value}.")
118
+
119
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
120
+ # or (0., 0.) for hue, do nothing
121
+ if value[0] == value[1] == center:
122
+ return None
123
+ else:
124
+ return tuple(value)
125
+
126
+ @staticmethod
127
+ def get_params(
128
+ brightness: Optional[List[float]],
129
+ contrast: Optional[List[float]],
130
+ saturation: Optional[List[float]],
131
+ hue: Optional[List[float]],
132
+ ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
133
+ """Get the parameters for the randomized transform to be applied on image.
134
+
135
+ Args:
136
+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
137
+ uniformly. Pass None to turn off the transformation.
138
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
139
+ uniformly. Pass None to turn off the transformation.
140
+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
141
+ uniformly. Pass None to turn off the transformation.
142
+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
143
+ Pass None to turn off the transformation.
144
+
145
+ Returns:
146
+ tuple: The parameters used to apply the randomized transform
147
+ along with their random order.
148
+ """
149
+ fn_idx = torch.randperm(4)
150
+
151
+ b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
152
+ c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
153
+ s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
154
+ h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
155
+
156
+ return fn_idx, b, c, s, h
157
+
158
+
159
+ def forward(self, sample):
160
+ """
161
+ Args:
162
+ img (PIL Image or Tensor): Input image.
163
+
164
+ Returns:
165
+ PIL Image or Tensor: Color jittered image.
166
+ """
167
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
168
+ self.brightness, self.contrast, self.saturation, self.hue
169
+ )
170
+
171
+ image = sample['image']
172
+
173
+ for fn_id in fn_idx:
174
+ if fn_id == 0 and brightness_factor is not None:
175
+ image = f.adjust_brightness(image, brightness_factor)
176
+ elif fn_id == 1 and contrast_factor is not None:
177
+ image = f.adjust_contrast(image, contrast_factor)
178
+ elif fn_id == 2 and saturation_factor is not None:
179
+ image = f.adjust_saturation(image, saturation_factor)
180
+ elif fn_id == 3 and hue_factor is not None:
181
+ image = f.adjust_hue(image, hue_factor)
182
+
183
+ return {'image': image,
184
+ 'target': sample['target'],
185
+ 'mask': sample['mask']}
186
+
187
+
188
+ def __repr__(self) -> str:
189
+ s = (
190
+ f"{self.__class__.__name__}("
191
+ f"brightness={self.brightness}"
192
+ f", contrast={self.contrast}"
193
+ f", saturation={self.saturation}"
194
+ f", hue={self.hue})"
195
+ )
196
+ return s
197
+
198
+
199
+
200
+ transforms = v2.Compose([
201
+ ToTensor(),
202
+ RandomHorizontalFlip(p=.5),
203
+ ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
204
+ AddGaussianNoise(0, .1)
205
+ ])
206
+
207
+
208
+ no_transforms = v2.Compose([
209
+ ToTensor(),
210
+ ])
211
+
212
+
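A minimal usage sketch for the pipeline defined above (the 58 target channels mirror the swap_dict; the 960x540 frame and the all-zero arrays are illustrative assumptions, not values taken from this repo):

import numpy as np
from PIL import Image

sample = {
    'image': Image.new('RGB', (960, 540)),                  # dummy frame, W=960 / H=540
    'target': np.zeros((58, 540, 960), dtype=np.float32),   # one heatmap per keypoint channel
    'mask': np.zeros(58, dtype=np.float32),                  # per-keypoint validity mask
}
out = transforms(sample)  # ToTensor -> RandomHorizontalFlip -> ColorJitter -> AddGaussianNoise
print(out['image'].shape, out['target'].shape, out['mask'].shape)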
model/transformsWC_l.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ import math
3
+ import numbers
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms.functional as f
6
+ import torchvision.transforms as T
7
+ import torchvision.transforms.v2 as v2
8
+
9
+ from torchvision import transforms as _transforms
10
+ from typing import List, Optional, Tuple, Union
11
+ from scipy import ndimage
12
+ from torch import Tensor
13
+
14
+ from sn_calibration.src.evaluate_extremities import mirror_labels
15
+
16
+ class ToTensor(torch.nn.Module):
17
+ def __call__(self, sample):
18
+ image, target, mask = sample['image'], sample['target'], sample['mask']
19
+
20
+ return {'image': f.to_tensor(image).float(),
21
+ 'target': torch.from_numpy(target).float(),
22
+ 'mask': torch.from_numpy(mask).float()}
23
+
24
+ def __repr__(self) -> str:
25
+ return f"{self.__class__.__name__}()"
26
+
27
+ class RandomHorizontalFlip(torch.nn.Module):
28
+ def __init__(self, p=0.5):
29
+ super().__init__()
30
+ self.p = p
31
+ self.swap_dict = {1: 4, 2: 5, 3: 6, 4: 1, 5: 2, 6: 3, 7: 10, 8: 12, 9: 11, 10: 7, 11: 9, 12: 8, 13: 13,
32
+ 14: 14, 15: 16, 16: 15, 17: 17, 18: 21, 19: 22, 20: 23, 21: 18, 22: 19, 23: 20, 24:24}
33
+
34
+ def forward(self, sample):
35
+ if torch.rand(1) < self.p:
36
+ image, target, mask = sample['image'], sample['target'], sample['mask']
37
+ image = f.hflip(image)
38
+ target = f.hflip(target)
39
+
40
+ target_swap, mask_swap = self.swap_layers(target, mask)
41
+
42
+ return {'image': image,
43
+ 'target': target_swap,
44
+ 'mask': mask_swap}
45
+ else:
46
+ return {'image': sample['image'],
47
+ 'target': sample['target'],
48
+ 'mask': sample['mask']}
49
+
50
+
51
+ def swap_layers(self, target, mask):
52
+ target_swap = torch.zeros_like(target)
53
+ mask_swap = torch.zeros_like(mask)
54
+ for kp in self.swap_dict.keys():
55
+ kp_swap = self.swap_dict[kp]
56
+ target_swap[kp_swap-1, :, :] = target[kp-1, :, :].clone()
57
+ mask_swap[kp_swap-1] = mask[kp-1].clone()
58
+
59
+ return target_swap, mask_swap
60
+
61
+
62
+ def __repr__(self) -> str:
63
+ return f"{self.__class__.__name__}(p={self.p})"
64
+
65
+
66
+ class AddGaussianNoise(torch.nn.Module):
67
+ def __init__(self, mean=0., std=2.):
+ super().__init__()
68
+ self.std = std
69
+ self.mean = mean
70
+
71
+ def __call__(self, sample):
72
+ image = sample['image']
73
+ image += torch.randn(image.size()) * self.std + self.mean
74
+ image = torch.clip(image, 0, 1)
75
+
76
+ return {'image': image,
77
+ 'target': sample['target'],
78
+ 'mask': sample['mask']}
79
+
80
+ def __repr__(self):
81
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
82
+
83
+
84
+ class ColorJitter(torch.nn.Module):
85
+
86
+ def __init__(
87
+ self,
88
+ brightness: Union[float, Tuple[float, float]] = 0,
89
+ contrast: Union[float, Tuple[float, float]] = 0,
90
+ saturation: Union[float, Tuple[float, float]] = 0,
91
+ hue: Union[float, Tuple[float, float]] = 0,
92
+ ) -> None:
93
+ super().__init__()
94
+ self.brightness = self._check_input(brightness, "brightness")
95
+ self.contrast = self._check_input(contrast, "contrast")
96
+ self.saturation = self._check_input(saturation, "saturation")
97
+ self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
98
+
99
+ @torch.jit.unused
100
+ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
101
+ if isinstance(value, numbers.Number):
102
+ if value < 0:
103
+ raise ValueError(f"If {name} is a single number, it must be non negative.")
104
+ value = [center - float(value), center + float(value)]
105
+ if clip_first_on_zero:
106
+ value[0] = max(value[0], 0.0)
107
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
108
+ value = [float(value[0]), float(value[1])]
109
+ else:
110
+ raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
111
+
112
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
113
+ raise ValueError(f"{name} values should be between {bound}, but got {value}.")
114
+
115
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
116
+ # or (0., 0.) for hue, do nothing
117
+ if value[0] == value[1] == center:
118
+ return None
119
+ else:
120
+ return tuple(value)
121
+
122
+ @staticmethod
123
+ def get_params(
124
+ brightness: Optional[List[float]],
125
+ contrast: Optional[List[float]],
126
+ saturation: Optional[List[float]],
127
+ hue: Optional[List[float]],
128
+ ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
129
+ """Get the parameters for the randomized transform to be applied on image.
130
+
131
+ Args:
132
+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
133
+ uniformly. Pass None to turn off the transformation.
134
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
135
+ uniformly. Pass None to turn off the transformation.
136
+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
137
+ uniformly. Pass None to turn off the transformation.
138
+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
139
+ Pass None to turn off the transformation.
140
+
141
+ Returns:
142
+ tuple: The parameters used to apply the randomized transform
143
+ along with their random order.
144
+ """
145
+ fn_idx = torch.randperm(4)
146
+
147
+ b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
148
+ c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
149
+ s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
150
+ h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
151
+
152
+ return fn_idx, b, c, s, h
153
+
154
+
155
+ def forward(self, sample):
156
+ """
157
+ Args:
158
+ img (PIL Image or Tensor): Input image.
159
+
160
+ Returns:
161
+ PIL Image or Tensor: Color jittered image.
162
+ """
163
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
164
+ self.brightness, self.contrast, self.saturation, self.hue
165
+ )
166
+
167
+ image = sample['image']
168
+
169
+ for fn_id in fn_idx:
170
+ if fn_id == 0 and brightness_factor is not None:
171
+ image = f.adjust_brightness(image, brightness_factor)
172
+ elif fn_id == 1 and contrast_factor is not None:
173
+ image = f.adjust_contrast(image, contrast_factor)
174
+ elif fn_id == 2 and saturation_factor is not None:
175
+ image = f.adjust_saturation(image, saturation_factor)
176
+ elif fn_id == 3 and hue_factor is not None:
177
+ image = f.adjust_hue(image, hue_factor)
178
+
179
+ return {'image': image,
180
+ 'target': sample['target'],
181
+ 'mask': sample['mask']}
182
+
183
+
184
+ def __repr__(self) -> str:
185
+ s = (
186
+ f"{self.__class__.__name__}("
187
+ f"brightness={self.brightness}"
188
+ f", contrast={self.contrast}"
189
+ f", saturation={self.saturation}"
190
+ f", hue={self.hue})"
191
+ )
192
+ return s
193
+
194
+
195
+
196
+ transforms = v2.Compose([
197
+ ToTensor(),
198
+ RandomHorizontalFlip(p=.5),
199
+ ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
200
+ AddGaussianNoise(0, .1)
201
+ ])
202
+
203
+
204
+ no_transforms = v2.Compose([
205
+ ToTensor(),
206
+ ])
207
+
208
+
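A quick sanity check (sketch) for the flip mapping above: applying the channel swap twice must return every keypoint index to itself, i.e. swap_dict has to be an involution.

flip = RandomHorizontalFlip(p=1.0)
assert all(flip.swap_dict[flip.swap_dict[k]] == k for k in flip.swap_dict)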
model/transforms_l.py ADDED
@@ -0,0 +1,360 @@
1
+ import math
+ import warnings
2
+ import torch
3
+ import numbers
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms.functional as f
7
+ import torchvision.transforms.v2 as v2
8
+
9
+ from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
10
+ from torchvision import transforms as _transforms
11
+ from typing import List, Optional, Tuple, Union
12
+ from scipy import ndimage
13
+ from torch import Tensor
14
+
15
+ from sn_calibration.src.evaluate_extremities import mirror_labels
16
+
17
+
18
+
19
+ class ToTensor(torch.nn.Module):
20
+ def __call__(self, sample):
21
+ image = sample['image']
22
+
23
+
24
+ return {'image': f.to_tensor(image).float(),
25
+ 'data': sample['data']}
26
+
27
+ def __repr__(self) -> str:
28
+ return f"{self.__class__.__name__}()"
29
+
30
+
31
+ class Normalize(torch.nn.Module):
32
+ def __init__(self, mean, std):
33
+ super().__init__()
34
+ self.mean = mean
35
+ self.std = std
36
+
37
+ def forward(self, sample):
38
+ image = sample['image']
39
+ image = f.normalize(image, self.mean, self.std)
40
+
41
+ return {'image': image,
42
+ 'data': sample['data']}
43
+
44
+
45
+ def __repr__(self) -> str:
46
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
47
+
48
+
49
+ FLIP_POSTS = {
50
+ 'Goal left post right': 'Goal left post left ',
51
+ 'Goal left post left ': 'Goal left post right',
52
+ 'Goal right post right': 'Goal right post left',
53
+ 'Goal right post left': 'Goal right post right'
54
+ }
55
+
56
+ h_lines = ['Goal left crossbar', 'Side line left', 'Small rect. left main', 'Big rect. left main', 'Middle line',
57
+ 'Big rect. right main', 'Small rect. right main', 'Side line right', 'Goal right crossbar']
58
+ v_lines = ['Side line top', 'Big rect. left top', 'Small rect. left top', 'Small rect. left bottom',
59
+ 'Big rect. left bottom', 'Big rect. right top', 'Small rect. right top', 'Small rect. right bottom',
60
+ 'Big rect. right bottom', 'Side line bottom']
61
+
62
+ def swap_top_bottom_names(line_name: str) -> str:
63
+ x: str = 'top'
64
+ y: str = 'bottom'
65
+ if x in line_name or y in line_name:
66
+ return y.join(part.replace(y, x) for part in line_name.split(x))
67
+ return line_name
68
+
69
+
70
+ def swap_posts_names(line_name: str) -> str:
71
+ if line_name in FLIP_POSTS:
72
+ return FLIP_POSTS[line_name]
73
+ return line_name
74
+
75
+
76
+ def flip_annot_names(annot, swap_top_bottom: bool = True, swap_posts: bool = True):
77
+ annot = mirror_labels(annot)
78
+ if swap_top_bottom:
79
+ annot = {swap_top_bottom_names(k): v for k, v in annot.items()}
80
+ if swap_posts:
81
+ annot = {swap_posts_names(k): v for k, v in annot.items()}
82
+ return annot
83
+
84
+
85
+ class LRAmbiguityFix(torch.nn.Module):
86
+ def __init__(self, v_th, h_th):
87
+ super().__init__()
88
+ self.v_th = v_th
89
+ self.h_th = h_th
90
+
91
+ def forward(self, sample):
92
+ data = sample['data']
93
+
94
+ if len(data) == 0:
95
+ return {'image': sample['image'],
96
+ 'data': sample['data']}
97
+
98
+ n_left, n_right = self.compute_n_sides(data)
99
+
100
+ angles_v, angles_h = [], []
101
+ for line in data.keys():
102
+ line_points = []
103
+ for point in data[line]:
104
+ line_points.append((point['x'], point['y']))
105
+
106
+ sorted_points = sorted(line_points, key=lambda point: (point[0], point[1]))
107
+ pi, pf = sorted_points[0], sorted_points[-1]
108
+ if line in h_lines:
109
+ angle_h = self.calculate_angle_h(pi[0], pi[1], pf[0], pf[1])
110
+ if angle_h:
111
+ angles_h.append(abs(angle_h))
112
+ if line in v_lines:
113
+ angle_v = self.calculate_angle_v(pi[0], pi[1], pf[0], pf[1])
114
+ if angle_v:
115
+ angles_v.append(abs(angle_v))
116
+
117
+
118
+ if len(angles_h) > 0 and len(angles_v) > 0:
119
+ if np.mean(angles_h) < self.h_th and np.mean(angles_v) < self.v_th:
120
+ if n_right > n_left:
121
+ data = flip_annot_names(data, swap_top_bottom=False, swap_posts=False)
122
+
123
+ return {'image': sample['image'],
124
+ 'data': data}
125
+
126
+ def calculate_angle_h(self, x1, y1, x2, y2):
127
+ if not x2 - x1 == 0:
128
+ slope = (y2 - y1) / (x2 - x1)
129
+ angle = math.atan(slope)
130
+ angle_degrees = math.degrees(angle)
131
+ return angle_degrees
132
+ else:
133
+ return None
134
+ def calculate_angle_v(self, x1, y1, x2, y2):
135
+ if not x2 - x1 == 0:
136
+ slope = (y2 - y1) / (x2 - x1)
137
+ angle = math.atan(1 / slope) if slope != 0 else math.pi / 2 # Avoid division by zero
138
+ angle_degrees = math.degrees(angle)
139
+ return angle_degrees
140
+ else:
141
+ return None
142
+
143
+ def compute_n_sides(self, data):
144
+ n_left, n_right = 0, 0
145
+ for line in data:
146
+ line_words = line.split()[:3]
147
+ if 'left' in line_words:
148
+ n_left += 1
149
+ elif 'right' in line_words:
150
+ n_right += 1
151
+ return n_left, n_right
152
+
153
+ def __repr__(self) -> str:
154
+ return f"{self.__class__.__name__}(v_th={self.v_th}, h_th={self.h_th})"
155
+
156
+
157
+ class RandomHorizontalFlip(torch.nn.Module):
158
+ def __init__(self, p=0.5):
159
+ super().__init__()
160
+ self.p = p
161
+
162
+ def forward(self, sample):
163
+ if torch.rand(1) < self.p:
164
+ image, data = sample['image'], sample['data']
165
+ image = f.hflip(image)
166
+ data = flip_annot_names(data)
167
+ for line in data:
168
+ for point in data[line]:
169
+ point['x'] = 1.0 - point['x']
170
+
171
+ return {'image': image,
172
+ 'data': data}
173
+ else:
174
+ return {'image': sample['image'],
175
+ 'data': sample['data']}
176
+
177
+ def __repr__(self) -> str:
178
+ return f"{self.__class__.__name__}(p={self.p})"
179
+
180
+
181
+ class AddGaussianNoise(torch.nn.Module):
182
+ def __init__(self, mean=0., std=2.):
+ super().__init__()
183
+ self.std = std
184
+ self.mean = mean
185
+
186
+ def __call__(self, sample):
187
+ image = sample['image']
188
+ image += torch.randn(image.size()) * self.std + self.mean
189
+ image = torch.clip(image, 0, 1)
190
+
191
+ return {'image': image,
192
+ 'data': sample['data']}
193
+
194
+ def __repr__(self):
195
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
196
+
197
+
198
+ class ColorJitter(torch.nn.Module):
199
+
200
+ def __init__(
201
+ self,
202
+ brightness: Union[float, Tuple[float, float]] = 0,
203
+ contrast: Union[float, Tuple[float, float]] = 0,
204
+ saturation: Union[float, Tuple[float, float]] = 0,
205
+ hue: Union[float, Tuple[float, float]] = 0,
206
+ ) -> None:
207
+ super().__init__()
208
+ self.brightness = self._check_input(brightness, "brightness")
209
+ self.contrast = self._check_input(contrast, "contrast")
210
+ self.saturation = self._check_input(saturation, "saturation")
211
+ self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
212
+
213
+ @torch.jit.unused
214
+ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
215
+ if isinstance(value, numbers.Number):
216
+ if value < 0:
217
+ raise ValueError(f"If {name} is a single number, it must be non negative.")
218
+ value = [center - float(value), center + float(value)]
219
+ if clip_first_on_zero:
220
+ value[0] = max(value[0], 0.0)
221
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
222
+ value = [float(value[0]), float(value[1])]
223
+ else:
224
+ raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
225
+
226
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
227
+ raise ValueError(f"{name} values should be between {bound}, but got {value}.")
228
+
229
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
230
+ # or (0., 0.) for hue, do nothing
231
+ if value[0] == value[1] == center:
232
+ return None
233
+ else:
234
+ return tuple(value)
235
+
236
+ @staticmethod
237
+ def get_params(
238
+ brightness: Optional[List[float]],
239
+ contrast: Optional[List[float]],
240
+ saturation: Optional[List[float]],
241
+ hue: Optional[List[float]],
242
+ ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
243
+ """Get the parameters for the randomized transform to be applied on image.
244
+
245
+ Args:
246
+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
247
+ uniformly. Pass None to turn off the transformation.
248
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
249
+ uniformly. Pass None to turn off the transformation.
250
+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
251
+ uniformly. Pass None to turn off the transformation.
252
+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
253
+ Pass None to turn off the transformation.
254
+
255
+ Returns:
256
+ tuple: The parameters used to apply the randomized transform
257
+ along with their random order.
258
+ """
259
+ fn_idx = torch.randperm(4)
260
+
261
+ b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
262
+ c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
263
+ s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
264
+ h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
265
+
266
+ return fn_idx, b, c, s, h
267
+
268
+
269
+ def forward(self, sample):
270
+ """
271
+ Args:
272
+ img (PIL Image or Tensor): Input image.
273
+
274
+ Returns:
275
+ PIL Image or Tensor: Color jittered image.
276
+ """
277
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
278
+ self.brightness, self.contrast, self.saturation, self.hue
279
+ )
280
+
281
+ image = sample['image']
282
+
283
+ for fn_id in fn_idx:
284
+ if fn_id == 0 and brightness_factor is not None:
285
+ image = f.adjust_brightness(image, brightness_factor)
286
+ elif fn_id == 1 and contrast_factor is not None:
287
+ image = f.adjust_contrast(image, contrast_factor)
288
+ elif fn_id == 2 and saturation_factor is not None:
289
+ image = f.adjust_saturation(image, saturation_factor)
290
+ elif fn_id == 3 and hue_factor is not None:
291
+ image = f.adjust_hue(image, hue_factor)
292
+
293
+ return {'image': image,
294
+ 'data': sample['data']}
295
+
296
+
297
+ def __repr__(self) -> str:
298
+ s = (
299
+ f"{self.__class__.__name__}("
300
+ f"brightness={self.brightness}"
301
+ f", contrast={self.contrast}"
302
+ f", saturation={self.saturation}"
303
+ f", hue={self.hue})"
304
+ )
305
+ return s
306
+
307
+
308
+ class Resize(torch.nn.Module):
309
+ def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
310
+ super().__init__()
311
+ self.size = size
312
+
313
+ # Backward compatibility with integer value
314
+ if isinstance(interpolation, int):
315
+ warnings.warn(
316
+ "Argument interpolation should be of type InterpolationMode instead of int. "
317
+ "Please, use InterpolationMode enum."
318
+ )
319
+ interpolation = _interpolation_modes_from_int(interpolation)
320
+
321
+ self.interpolation = interpolation
322
+
323
+ def forward(self, sample):
324
+ image = sample["image"]
325
+ image = f.resize(image, self.size, self.interpolation)
326
+
327
+ return {'image': image,
328
+ 'data': sample['data']}
329
+
330
+
331
+ def __repr__(self):
332
+ interpolate_str = self.interpolation.value
333
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
334
+
335
+
336
+
337
+
338
+ transforms = v2.Compose([
339
+ ToTensor(),
340
+ RandomHorizontalFlip(p=.5),
341
+ ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
342
+ AddGaussianNoise(0, .1)
343
+ ])
344
+
345
+ transforms_w_LR = v2.Compose([
346
+ ToTensor(),
347
+ RandomHorizontalFlip(p=.5),
348
+ LRAmbiguityFix(v_th=70, h_th=20),
349
+ ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
350
+ AddGaussianNoise(0, .1)
351
+ ])
352
+
353
+ no_transforms = v2.Compose([
354
+ ToTensor(),
355
+ ])
356
+
357
+ no_transforms_w_LR = v2.Compose([
358
+ ToTensor(),
359
+ LRAmbiguityFix(v_th=70, h_th=20)
360
+ ])
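The name-swapping helpers above are pure string functions and can be sanity-checked in isolation; a small sketch (labels, including the trailing space in 'Goal left post left ', are copied verbatim from FLIP_POSTS and the line lists above):

assert swap_top_bottom_names('Side line top') == 'Side line bottom'
assert swap_top_bottom_names('Side line bottom') == 'Side line top'
assert swap_posts_names('Goal left post right') == 'Goal left post left '
assert swap_posts_names('Middle line') == 'Middle line'  # unrelated labels pass through unchanged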
requirements.txt ADDED
@@ -0,0 +1,44 @@
1
+ # # API Framework
2
+ # fastapi==0.104.1
3
+ # uvicorn[standard]==0.24.0
4
+ # python-multipart==0.0.6
5
+ # pydantic==2.5.0
6
+
7
+ # # Core dependencies (compatible with Python 3.10)
8
+ # numpy==1.24.3
9
+ # opencv-python-headless==4.8.1.78
10
+ # pillow==10.1.0
11
+ # scipy==1.11.4
12
+ # PyYAML==6.0.1
13
+ # lsq-ellipse==2.2.1
14
+ # shapely==2.0.2
15
+
16
+ # # PyTorch CPU (compatible Python 3.10)
17
+ # torch==2.1.0
18
+ # torchvision==0.16.0
19
+
20
+ # # Utilities
21
+ # tqdm==4.66.1
22
+
23
+ # requirements_clean.txt
24
+ fastapi
25
+ uvicorn[standard]
26
+ python-multipart
27
+ pydantic
28
+
29
+ # Core dependencies
30
+ numpy
31
+ opencv-python-headless
32
+ pillow
33
+ scipy
34
+ PyYAML
35
+ lsq-ellipse
36
+ shapely
37
+
38
+ # PyTorch CPU-only wheels (much lighter than the CUDA builds)
39
+ --extra-index-url https://download.pytorch.org/whl/cpu
40
+ torch
41
+ torchvision
42
+
43
+ # Utilities
44
+ tqdm
run_api.py ADDED
@@ -0,0 +1,13 @@
1
+ import uvicorn
2
+
3
+ if __name__ == "__main__":
4
+ print("🚀 Starting the Football Vision Calibration API...")
5
+ print("📍 API available at: http://localhost:8000")
6
+ print("📖 Documentation: http://localhost:8000/docs")
7
+
8
+ uvicorn.run(
9
+ "api:app", # pass the app as an import string so auto-reload works
10
+ host="0.0.0.0",
11
+ port=8000,
12
+ reload=True # automatic reload during development
13
+ )
scripts/eval_tswc.py ADDED
@@ -0,0 +1,175 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import glob
5
+ import torch
6
+ import zipfile
7
+ import argparse
8
+ import numpy as np
9
+
10
+ from tqdm import tqdm
11
+
12
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
13
+ from utils.utils_calib import FramebyFrameCalib
14
+ from model.metrics import calc_iou_part, calc_iou_whole_with_poly, calc_reproj_error, calc_proj_error
15
+
16
+ def parse_args():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--root_dir", type=str, required=True)
19
+ parser.add_argument("--split", type=str, required=True)
20
+ parser.add_argument("--pred_file", type=str, required=True)
21
+
22
+ args = parser.parse_args()
23
+ return args
24
+
25
+
26
+ def get_homographies(file_paths):
27
+ npy_files = []
28
+ for file_path in file_paths:
29
+ directory_path = os.path.join(os.path.join(args.root_dir, "Annotations/80_95"), file_path)
30
+ if os.path.exists(directory_path):
31
+ files = os.listdir(directory_path)
32
+ npy_files.extend([os.path.join(directory_path, file) for file in files if file.endswith('.npy')])
33
+
34
+ npy_files = sorted(npy_files)
35
+ return npy_files
36
+
37
+
38
+ def make_file_name(file):
39
+ file = "TS-WorldCup/" + file.split("TS-WorldCup/")[-1]
40
+ splits = file.split('/')
41
+ side = splits[3]
42
+ match = splits[4]
43
+ image = splits[5]
44
+ frame = image.split('_homography.npy')[0]
45
+ return side + '-' + match + '-' + frame + '.json'
46
+
47
+
48
+ def pan_tilt_roll_to_orientation(pan, tilt, roll):
49
+ """
50
+ Conversion from euler angles to orientation matrix.
51
+ :param pan:
52
+ :param tilt:
53
+ :param roll:
54
+ :return: orientation matrix
55
+ """
56
+ Rpan = np.array([
57
+ [np.cos(pan), -np.sin(pan), 0],
58
+ [np.sin(pan), np.cos(pan), 0],
59
+ [0, 0, 1]])
60
+ Rroll = np.array([
61
+ [np.cos(roll), -np.sin(roll), 0],
62
+ [np.sin(roll), np.cos(roll), 0],
63
+ [0, 0, 1]])
64
+ Rtilt = np.array([
65
+ [1, 0, 0],
66
+ [0, np.cos(tilt), -np.sin(tilt)],
67
+ [0, np.sin(tilt), np.cos(tilt)]])
68
+ rotMat = np.dot(Rpan, np.dot(Rtilt, Rroll))
69
+ return rotMat
70
+
71
+ def get_sn_homography(cam_params: dict, batch_size=1):
72
+ # Extract relevant camera parameters from the dictionary
73
+ pan_degrees = cam_params['cam_params']['pan_degrees']
74
+ tilt_degrees = cam_params['cam_params']['tilt_degrees']
75
+ roll_degrees = cam_params['cam_params']['roll_degrees']
76
+ x_focal_length = cam_params['cam_params']['x_focal_length']
77
+ y_focal_length = cam_params['cam_params']['y_focal_length']
78
+ principal_point = np.array(cam_params['cam_params']['principal_point'])
79
+ position_meters = np.array(cam_params['cam_params']['position_meters'])
80
+
81
+ pan = pan_degrees * np.pi / 180.
82
+ tilt = tilt_degrees * np.pi / 180.
83
+ roll = roll_degrees * np.pi / 180.
84
+
85
+ rotation = np.array([[-np.sin(pan) * np.sin(roll) * np.cos(tilt) + np.cos(pan) * np.cos(roll),
86
+ np.sin(pan) * np.cos(roll) + np.sin(roll) * np.cos(pan) * np.cos(tilt), np.sin(roll) * np.sin(tilt)],
87
+ [-np.sin(pan) * np.cos(roll) * np.cos(tilt) - np.sin(roll) * np.cos(pan),
88
+ -np.sin(pan) * np.sin(roll) + np.cos(pan) * np.cos(roll) * np.cos(tilt), np.sin(tilt) * np.cos(roll)],
89
+ [np.sin(pan) * np.sin(tilt), -np.sin(tilt) * np.cos(pan), np.cos(tilt)]], dtype='float')
90
+
91
+ rotation = np.transpose(pan_tilt_roll_to_orientation(pan, tilt, roll))
92
+
93
+ def convert_homography_SN_to_WC14(H):
94
+ T = np.eye(3)
95
+ T[0, -1] = 105 / 2
96
+ T[1, -1] = 68 / 2
97
+ meter2yard = 1.09361
98
+ S = np.eye(3)
99
+ S[0, 0] = meter2yard
100
+ S[1, 1] = meter2yard
101
+ H_SN = S @ (T @ H)
102
+ return H_SN
103
+
104
+ def get_homography_by_index(homography_file):
105
+ homography = np.load(homography_file)
106
+ homography = homography / homography[2:3, 2:3]
107
+ return homography
108
+
109
+
110
+ if __name__ == "__main__":
111
+ args = parse_args()
112
+
113
+ missed = 0
114
+ iou_part_list, iou_whole_list = [], []
115
+ rep_err_list, proj_err_list = [], []
116
+
117
+ with open(args.root_dir + args.split + '.txt', 'r') as file:
118
+ # Read lines from the file and remove trailing newline characters
119
+ seqs = [line.strip() for line in file.readlines()]
120
+
121
+ homographies = get_homographies(seqs)
122
+ prediction_archive = zipfile.ZipFile(args.pred_file, 'r')
123
+ cam = FramebyFrameCalib(1280, 720, denormalize=True)
124
+
125
+ for h_gt in tqdm(homographies):
126
+ file_name = h_gt.split('/')[-1].split('.')[0]
127
+ pred_name = make_file_name(h_gt)
128
+
129
+ if pred_name not in prediction_archive.namelist():
130
+ missed += 1
131
+ continue
132
+
133
+ homography_gt = get_homography_by_index(h_gt)
134
+ final_dict = prediction_archive.read(pred_name)
135
+ final_dict = json.loads(final_dict.decode('utf-8'))
136
+ keypoints_dict = final_dict['kp']
137
+ lines_dict = final_dict['lines']
138
+ keypoints_dict = {int(key): value for key, value in keypoints_dict.items()}
139
+ lines_dict = {int(key): value for key, value in lines_dict.items()}
140
+
141
+ cam.update(keypoints_dict, lines_dict)
142
+ final_dict = cam.heuristic_voting_ground(refine_lines=True)
143
+ #homography_pred, ret = cam.get_homography_from_ground_plane(use_ransac=20, inverse=True, refine_lines=True)
144
+ if final_dict is None:
145
+ #if homography_pred is None:
146
+ missed += 1
147
+ continue
148
+ homography_pred = final_dict["homography"]
149
+ homography_pred = convert_homography_SN_to_WC14(homography_pred)
150
+
151
+ iou_p = calc_iou_part(homography_pred, homography_gt)
152
+ iou_w, _, _ = calc_iou_whole_with_poly(homography_pred, homography_gt)
153
+ rep_err = calc_reproj_error(homography_pred, homography_gt)
154
+ proj_err = calc_proj_error(homography_pred, homography_gt)
155
+
156
+ iou_part_list.append(iou_p)
157
+ iou_whole_list.append(iou_w)
158
+ rep_err_list.append(rep_err)
159
+ proj_err_list.append(proj_err)
160
+
161
+
162
+ print(f'Completeness: {1-missed/len(homographies)}')
163
+ print('IOU Part')
164
+ print(f'mean: {np.mean(iou_part_list)} \t median: {np.median(iou_part_list)}')
165
+ print('\nIOU Whole')
166
+ print(f'mean: {np.mean(iou_whole_list)} \t median: {np.median(iou_whole_list)}')
167
+ print('\nReprojection Err.')
168
+ print(f'mean: {np.mean(rep_err_list)} \t median: {np.median(rep_err_list)}')
169
+ print('\nProjection Err. (meters)')
170
+ print(f'mean: {np.mean(proj_err_list) * 0.9144} \t median: {np.median(proj_err_list) * 0.9144}')
171
+
172
+
173
+
174
+
175
+
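A short sketch of what convert_homography_SN_to_WC14 does: it composes the predicted homography with a translation by half of the 105 m x 68 m pitch and a metre-to-yard scaling, so SoccerNet's centre-origin, metric pitch coordinates line up with the WorldCup14 ground truth.

import numpy as np

H = convert_homography_SN_to_WC14(np.eye(3))
p = H @ np.array([0.0, 0.0, 1.0])   # the pitch centre in SN coordinates
print(p[:2] / p[2])                 # ~ (57.41, 37.18) yards, i.e. (105/2, 68/2) metres rescaled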
scripts/eval_wc14.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import glob
5
+
6
+ import torch
7
+ import zipfile
8
+ import argparse
9
+ import numpy as np
10
+
11
+ from tqdm import tqdm
12
+ from typing import final
13
+
14
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
15
+ from utils.utils_calib import FramebyFrameCalib
16
+ from model.metrics import calc_iou_part, calc_iou_whole_with_poly, calc_reproj_error, calc_proj_error
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--root_dir", type=str, required=True)
21
+ parser.add_argument("--split", type=str, required=True)
22
+ parser.add_argument("--pred_file", type=str, required=True)
23
+
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+
28
+ def pan_tilt_roll_to_orientation(pan, tilt, roll):
29
+ """
30
+ Conversion from euler angles to orientation matrix.
31
+ :param pan:
32
+ :param tilt:
33
+ :param roll:
34
+ :return: orientation matrix
35
+ """
36
+ Rpan = np.array([
37
+ [np.cos(pan), -np.sin(pan), 0],
38
+ [np.sin(pan), np.cos(pan), 0],
39
+ [0, 0, 1]])
40
+ Rroll = np.array([
41
+ [np.cos(roll), -np.sin(roll), 0],
42
+ [np.sin(roll), np.cos(roll), 0],
43
+ [0, 0, 1]])
44
+ Rtilt = np.array([
45
+ [1, 0, 0],
46
+ [0, np.cos(tilt), -np.sin(tilt)],
47
+ [0, np.sin(tilt), np.cos(tilt)]])
48
+ rotMat = np.dot(Rpan, np.dot(Rtilt, Rroll))
49
+ return rotMat
50
+
51
+ def get_sn_homography(cam_params: dict, batch_size=1):
52
+ # Extract relevant camera parameters from the dictionary
53
+ pan_degrees = cam_params['cam_params']['pan_degrees']
54
+ tilt_degrees = cam_params['cam_params']['tilt_degrees']
55
+ roll_degrees = cam_params['cam_params']['roll_degrees']
56
+ x_focal_length = cam_params['cam_params']['x_focal_length']
57
+ y_focal_length = cam_params['cam_params']['y_focal_length']
58
+ principal_point = np.array(cam_params['cam_params']['principal_point'])
59
+ position_meters = np.array(cam_params['cam_params']['position_meters'])
60
+
61
+ pan = pan_degrees * np.pi / 180.
62
+ tilt = tilt_degrees * np.pi / 180.
63
+ roll = roll_degrees * np.pi / 180.
64
+
65
+ rotation = np.array([[-np.sin(pan) * np.sin(roll) * np.cos(tilt) + np.cos(pan) * np.cos(roll),
66
+ np.sin(pan) * np.cos(roll) + np.sin(roll) * np.cos(pan) * np.cos(tilt), np.sin(roll) * np.sin(tilt)],
67
+ [-np.sin(pan) * np.cos(roll) * np.cos(tilt) - np.sin(roll) * np.cos(pan),
68
+ -np.sin(pan) * np.sin(roll) + np.cos(pan) * np.cos(roll) * np.cos(tilt), np.sin(tilt) * np.cos(roll)],
69
+ [np.sin(pan) * np.sin(tilt), -np.sin(tilt) * np.cos(pan), np.cos(tilt)]], dtype='float')
70
+
71
+ rotation = np.transpose(pan_tilt_roll_to_orientation(pan, tilt, roll))
72
+
73
+ def convert_homography_SN_to_WC14(H):
74
+ T = np.eye(3)
75
+ T[0, -1] = 105 / 2
76
+ T[1, -1] = 68 / 2
77
+ meter2yard = 1.09361
78
+ S = np.eye(3)
79
+ S[0, 0] = meter2yard
80
+ S[1, 1] = meter2yard
81
+ H_SN = S @ (T @ H)
82
+ return H_SN
83
+
84
+ def get_homography_by_index(homography_file):
85
+ with open(homography_file, 'r') as file:
86
+ lines = file.readlines()
87
+ matrix_elements = []
88
+ for line in lines:
89
+ matrix_elements.extend([float(element) for element in line.split()])
90
+ homography = np.array(matrix_elements).reshape((3, 3))
91
+ homography = homography / homography[2:3, 2:3]
92
+ return homography
93
+
94
+ if __name__ == "__main__":
95
+ args = parse_args()
96
+
97
+ missed = 0
98
+ iou_part_list, iou_whole_list = [], []
99
+ rep_err_list, proj_err_list = [], []
100
+
101
+ homographies = glob.glob(os.path.join(args.root_dir + args.split, "*.homographyMatrix"))
102
+ prediction_archive = zipfile.ZipFile(args.pred_file, 'r')
103
+ cam = FramebyFrameCalib(1280, 720, denormalize=True)
104
+
105
+ for h_gt in tqdm(homographies):
106
+ file_name = h_gt.split('/')[-1].split('.')[0]
107
+ pred_name = file_name + '.json'
108
+
109
+ if pred_name not in prediction_archive.namelist():
110
+ missed += 1
111
+ continue
112
+
113
+ homography_gt = get_homography_by_index(h_gt)
114
+ final_dict = prediction_archive.read(pred_name)
115
+ final_dict = json.loads(final_dict.decode('utf-8'))
116
+ keypoints_dict = final_dict['kp']
117
+ lines_dict = final_dict['lines']
118
+ keypoints_dict = {int(key): value for key, value in keypoints_dict.items()}
119
+ lines_dict = {int(key): value for key, value in lines_dict.items()}
120
+
121
+ cam.update(keypoints_dict, lines_dict)
122
+ final_dict = cam.heuristic_voting_ground(refine_lines=True)
123
+
124
+ if final_dict is None:
125
+ missed += 1
126
+ continue
127
+
128
+ homography_pred = final_dict["homography"]
129
+ homography_pred = convert_homography_SN_to_WC14(homography_pred)
130
+
131
+ iou_p = calc_iou_part(homography_pred, homography_gt)
132
+ iou_w, _, _ = calc_iou_whole_with_poly(homography_pred, homography_gt)
133
+ rep_err = calc_reproj_error(homography_pred, homography_gt)
134
+ proj_err = calc_proj_error(homography_pred, homography_gt)
135
+
136
+ iou_part_list.append(iou_p)
137
+ iou_whole_list.append(iou_w)
138
+ rep_err_list.append(rep_err)
139
+ proj_err_list.append(proj_err)
140
+
141
+
142
+ print(f'Completeness: {1-missed/len(homographies)}')
143
+ print('IOU Part')
144
+ print(f'mean: {np.mean(iou_part_list)} \t median: {np.median(iou_part_list)}')
145
+ print('\nIOU Whole')
146
+ print(f'mean: {np.mean(iou_whole_list)} \t median: {np.median(iou_whole_list)}')
147
+ print('\nReprojection Err.')
148
+ print(f'mean: {np.mean(rep_err_list)} \t median: {np.median(rep_err_list)}')
149
+ print('\nProjection Err. (meters)')
150
+ print(f'mean: {np.mean(proj_err_list) * 0.9144} \t median: {np.median(proj_err_list) * 0.9144}')
151
+
152
+
153
+
154
+
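As a sketch, the Euler-angle helper above can be checked against the defining properties of a rotation matrix (orthonormal, determinant 1) for arbitrary pan/tilt/roll values given in radians:

import numpy as np

R = pan_tilt_roll_to_orientation(pan=0.3, tilt=1.2, roll=-0.05)
assert np.allclose(R @ R.T, np.eye(3))
assert np.isclose(np.linalg.det(R), 1.0)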
scripts/inference_sn.py ADDED
@@ -0,0 +1,186 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import glob
5
+ import yaml
6
+ import torch
7
+ import zipfile
8
+ import argparse
9
+ import warnings
10
+ import numpy as np
11
+ import torchvision.transforms as T
12
+ import torchvision.transforms.functional as f
13
+
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+
17
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
18
+
19
+ from model.cls_hrnet import get_cls_net
20
+ from model.cls_hrnet_l import get_cls_net as get_cls_net_l
21
+ from utils.utils_keypoints import KeypointsDB
22
+ from utils.utils_lines import LineKeypointsDB
23
+ from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
24
+ coords_to_dict, complete_keypoints
25
+ from utils.utils_calib import FramebyFrameCalib
26
+
27
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
28
+ warnings.filterwarnings("ignore", category=np.RankWarning)
29
+
30
+
31
+ def parse_args():
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument("--cfg", type=str, required=True,
34
+ help="Path to the (kp model) configuration file")
35
+ parser.add_argument("--cfg_l", type=str, required=True,
36
+ help="Path to the (line model) configuration file")
37
+ parser.add_argument("--root_dir", type=str, required=True,
38
+ help="Root directory")
39
+ parser.add_argument("--split", type=str, required=True,
40
+ help="Dataset split")
41
+ parser.add_argument("--save_dir", type=str, required=True,
42
+ help="Saving file path")
43
+ parser.add_argument("--weights_kp", type=str, required=True,
44
+ help="Model (keypoints) weights to use")
45
+ parser.add_argument("--weights_line", type=str, required=True,
46
+ help="Model (lines) weights to use")
47
+ parser.add_argument("--cuda", type=str, default="cuda:0",
48
+ help="CUDA device index (default: 'cuda:0')")
49
+ parser.add_argument("--kp_th", type=float, default=0.1)
50
+ parser.add_argument("--line_th", type=float, default=0.1)
51
+ parser.add_argument("--max_reproj_err", type=float, default=50.0)
52
+ parser.add_argument("--main_cam_only", action='store_true')
53
+ parser.add_argument('--use_gt', action='store_true', help='Use ground truth annotations (default: False)')
54
+
55
+ args = parser.parse_args()
56
+ return args
57
+
58
+
59
+ if __name__ == "__main__":
60
+ args = parse_args()
61
+
62
+ files = glob.glob(os.path.join(args.root_dir + args.split, "*.jpg"))
63
+
64
+ if args.main_cam_only:
65
+ cam_info = json.load(open(args.root_dir + args.split + '/match_info_cam_gt.json'))
66
+ files = [file for file in files if file.split('/')[-1] in cam_info.keys()]
67
+ files = [file for file in files if cam_info[file.split('/')[-1]]['camera'] == 'Main camera center']
68
+ # files = [file for file in files if int(match_info[file.split('/')[-1]]['ms_time']) == \
69
+ # int(match_info[file.split('/')[-1]]['replay_time'])]
70
+
71
+ if args.main_cam_only:
72
+ zip_name = args.save_dir + args.split + '_main.zip'
73
+ else:
74
+ zip_name = args.save_dir + args.split + '.zip'
75
+
76
+ if args.use_gt:
77
+ if args.main_cam_only:
78
+ zip_name_pred = args.save_dir + args.split + '_main_gt.zip'
79
+ else:
80
+ zip_name_pred = args.save_dir + args.split + '_gt.zip'
81
+ else:
82
+ if args.main_cam_only:
83
+ zip_name_pred = args.save_dir + args.split + '_main_pred.zip'
84
+ else:
85
+ zip_name_pred = args.save_dir + args.split + '_pred.zip'
86
+
87
+ print(f"Saving results in {args.save_dir}")
88
+ print(f"file: {zip_name_pred}")
89
+
90
+ if args.use_gt:
91
+ transform = T.Resize((540, 960))
92
+ cam = FramebyFrameCalib(960, 540, denormalize=True)
93
+
94
+ with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
95
+ samples, complete = 0., 0.
96
+ for file in tqdm(files, desc="Processing Images"):
97
+ image = Image.open(file)
98
+ file_name = file.split('/')[-1].split('.')[0]
99
+ samples += 1
100
+
101
+ json_path = file.split('.')[0] + ".json"
102
+ with open(json_path) as json_file:
103
+ data = json.load(json_file)
104
+
105
+ kp_db = KeypointsDB(data, image)
106
+ line_db = LineKeypointsDB(data, image)
107
+ heatmaps, _ = kp_db.get_tensor_w_mask()
108
+ heatmaps = torch.tensor(heatmaps).unsqueeze(0)
109
+ heatmaps_l = line_db.get_tensor()
110
+ heatmaps_l = torch.tensor(heatmaps_l).unsqueeze(0)
111
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
112
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
113
+ kp_dict = coords_to_dict(kp_coords, threshold=0.01)
114
+ lines_dict = coords_to_dict(line_coords, threshold=0.01)
115
+
116
+ cam.update(kp_dict[0], lines_dict[0])
117
+ final_params_dict = cam.heuristic_voting()
118
+ # final_params_dict = cam.calibrate(5)
119
+
120
+ if final_params_dict:
121
+ complete += 1
122
+ cam_params = final_params_dict['cam_params']
124
+ json_data = json.dumps(cam_params)
125
+ zip_file.writestr(f"camera_{file_name}.json", json_data)
126
+
127
+ else:
128
+ device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
129
+ cfg = yaml.safe_load(open(args.cfg, 'r'))
130
+ cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
131
+
132
+ loaded_state = torch.load(args.weights_kp, map_location=device)
133
+ model = get_cls_net(cfg)
134
+ model.load_state_dict(loaded_state)
135
+ model.to(device)
136
+ model.eval()
137
+
138
+ loaded_state_l = torch.load(args.weights_line, map_location=device)
139
+ model_l = get_cls_net_l(cfg_l)
140
+ model_l.load_state_dict(loaded_state_l)
141
+ model_l.to(device)
142
+ model_l.eval()
143
+
144
+ transform = T.Resize((540, 960))
145
+ cam = FramebyFrameCalib(960, 540)
146
+
147
+ with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
148
+ samples, complete = 0., 0.
149
+ for file in tqdm(files, desc="Processing Images"):
150
+ image = Image.open(file)
151
+ file_name = file.split('/')[-1].split('.')[0]
152
+ samples += 1
153
+
154
+ with torch.no_grad():
155
+ image = f.to_tensor(image).float().to(device).unsqueeze(0)
156
+ image = image if image.size()[-1] == 960 else transform(image)
157
+ b, c, h, w = image.size()
158
+ heatmaps = model(image)
159
+ heatmaps_l = model_l(image)
160
+
161
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
162
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
163
+ kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th)
164
+ lines_dict = coords_to_dict(line_coords, threshold=args.line_th)
165
+ kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h)
166
+
167
+ cam.update(kp_dict, lines_dict)
168
+ final_params_dict = cam.heuristic_voting(refine_lines=True)
169
+
170
+ if final_params_dict:
171
+ if final_params_dict['rep_err'] <= args.max_reproj_err:
172
+ complete += 1
173
+ cam_params = final_params_dict['cam_params']
174
+ json_data = json.dumps(cam_params)
175
+ zip_file.writestr(f"camera_{file_name}.json", json_data)
176
+
177
+ with zipfile.ZipFile(zip_name, 'w') as zip_file:
178
+ for file in tqdm(files, desc="Processing Images"):
179
+ file_name = file.split('/')[-1].split('.')[0]
180
+ data = json.load(open(file.split('.')[0] + ".json"))
181
+ json_data = json.dumps(data)
182
+ zip_file.writestr(f"{file_name}.json", json_data)
183
+
184
+ print(f'Completed {complete} / {samples}')
185
+
186
+
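A sketch of reading back the archive written above; entry names follow the camera_{file_name}.json pattern used in this script, and the path matches PRED_FILE in scripts/run_pipeline_sn22.sh:

import json
import zipfile

with zipfile.ZipFile('inference/inference_3D/inference_sn22/test_main_pred.zip') as archive:
    for name in archive.namelist():
        cam_params = json.loads(archive.read(name).decode('utf-8'))
        print(name, sorted(cam_params.keys()))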
scripts/inference_tswc.py ADDED
@@ -0,0 +1,145 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import glob
5
+ import yaml
6
+ import torch
7
+ import zipfile
8
+ import argparse
9
+ import warnings
10
+ import numpy as np
11
+ import torchvision.transforms as T
12
+ import torchvision.transforms.functional as f
13
+
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+
17
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
18
+ from model.cls_hrnet import get_cls_net
19
+ from model.cls_hrnet_l import get_cls_net as get_cls_net_l
20
+ from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
21
+ complete_keypoints, coords_to_dict
22
+ from utils.utils_keypoints import KeypointsDB
23
+ from utils.utils_lines import LineKeypointsDB
24
+
25
+
26
+ def parse_args():
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--cfg", type=str, required=True,
29
+ help="Path to the (kp model) configuration file")
30
+ parser.add_argument("--cfg_l", type=str, required=True,
31
+ help="Path to the (line model) configuration file")
32
+ parser.add_argument("--root_dir", type=str, required=True,
33
+ help="Root directory")
34
+ parser.add_argument("--split", type=str, required=True,
35
+ help="Dataset split")
36
+ parser.add_argument("--save_dir", type=str, required=True,
37
+ help="Saving directory")
38
+ parser.add_argument("--weights_kp", type=str, required=True,
39
+ help="Model (keypoints) weights to use")
40
+ parser.add_argument("--weights_line", type=str, required=True,
41
+ help="Model (lines) weights to use")
42
+ parser.add_argument("--cuda", type=str, default="cuda:0",
43
+ help="CUDA device index (default: 'cuda:0')")
44
+ parser.add_argument("--kp_th", type=float, default=0.1)
45
+ parser.add_argument("--line_th", type=float, default=0.1)
46
+ parser.add_argument("--batch", type=int, default=1, help="Batch size")
47
+ parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
48
+
49
+
50
+ args = parser.parse_args()
51
+ return args
52
+
53
+
54
+ def get_files(file_paths):
55
+ jpg_files = []
56
+ for file_path in file_paths:
57
+ directory_path = os.path.join(os.path.join(args.root_dir, "Dataset/80_95"), file_path)
58
+ if os.path.exists(directory_path):
59
+ files = os.listdir(directory_path)
60
+ jpg_files.extend([os.path.join(directory_path, file) for file in files if file.endswith('.jpg')])
61
+
62
+ jpg_files = sorted(jpg_files)
63
+ return jpg_files
64
+
65
+ def get_homographies(file_paths):
66
+ npy_files = []
67
+ for file_path in file_paths:
68
+ directory_path = os.path.join(os.path.join(args.root_dir, "Annotations/80_95"), file_path)
69
+ if os.path.exists(directory_path):
70
+ files = os.listdir(directory_path)
71
+ npy_files.extend([os.path.join(directory_path, file) for file in files if file.endswith('.npy')])
72
+
73
+ npy_files = sorted(npy_files)
74
+ return npy_files
75
+
76
+
77
+ def make_file_name(file):
78
+ file = "TS-WorldCup/" + file.split("TS-WorldCup/")[-1]
79
+ splits = file.split('/')
80
+ side = splits[3]
81
+ match = splits[4]
82
+ image = splits[5]
83
+ frame = 'IMG_' + image.split('.')[0].split('_')[-1]
84
+ return side + '-' + match + '-' + frame
85
+
86
+
87
+ if __name__ == "__main__":
88
+ args = parse_args()
89
+
90
+ with open(args.root_dir + args.split + '.txt', 'r') as file:
91
+ # Read lines from the file and remove trailing newline characters
92
+ seqs = [line.strip() for line in file.readlines()]
93
+
94
+ files = get_files(seqs)
95
+ homographies = get_homographies(seqs)
96
+
97
+ zip_name_pred = args.save_dir + args.split + '_pred.zip'
98
+
99
+ device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
100
+ cfg = yaml.safe_load(open(args.cfg, 'r'))
101
+ cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
102
+
103
+ loaded_state = torch.load(args.weights_kp, map_location=device)
104
+ model = get_cls_net(cfg)
105
+ model.load_state_dict(loaded_state)
106
+ model.to(device)
107
+ model.eval()
108
+
109
+ loaded_state_l = torch.load(args.weights_line, map_location=device)
110
+ model_l = get_cls_net_l(cfg_l)
111
+ model_l.load_state_dict(loaded_state_l)
112
+ model_l.to(device)
113
+ model_l.eval()
114
+
115
+ transform = T.Resize((540, 960))
116
+
117
+ with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
118
+ for count in tqdm(range(len(files)), desc="Processing Images"):
119
+ image = Image.open(files[count])
120
+ image = f.to_tensor(image).float().to(device).unsqueeze(0)
121
+ image = image if image.size()[-1] == 960 else transform(image)
122
+ b, c, h, w = image.size()
123
+
124
+
125
+ with torch.no_grad():
126
+ heatmaps = model(image)
127
+ heatmaps_l = model_l(image)
128
+
129
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
130
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
131
+ kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th, ground_plane_only=True)
132
+ lines_dict = coords_to_dict(line_coords, threshold=args.line_th, ground_plane_only=True)
133
+ final_kp_dict, final_lines_dict = complete_keypoints(kp_dict[0], lines_dict[0],
134
+ w=w, h=h, normalize=True)
135
+ final_dict = {'kp': final_kp_dict, 'lines': final_lines_dict}
136
+
137
+ json_data = json.dumps(final_dict)
138
+ zip_file.writestr(f"{make_file_name(files[count])}.json", json_data)
139
+
140
+
141
+
142
+
143
+
144
+
145
+
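For illustration, make_file_name flattens a TS-WorldCup frame path into the key used for the prediction zip entries; the concrete path below is hypothetical, only the directory depth after 'TS-WorldCup/' matters:

path = 'datasets/TS-WorldCup/Dataset/80_95/left/France_vs_Germany/IMG_0042.jpg'
print(make_file_name(path))  # -> 'left-France_vs_Germany-IMG_0042'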
scripts/inference_wc14.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import glob
5
+ from typing import final
6
+
7
+ import yaml
8
+ import torch
9
+ import zipfile
10
+ import argparse
11
+ import warnings
12
+ import numpy as np
13
+ import torchvision.transforms as T
14
+ import torchvision.transforms.functional as f
15
+
16
+ from tqdm import tqdm
17
+ from PIL import Image
18
+
19
+ sys.path.insert(1, os.path.join(sys.path[0], '..'))
20
+ from model.cls_hrnet import get_cls_net
21
+ from model.cls_hrnet_l import get_cls_net as get_cls_net_l
22
+ from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
23
+ complete_keypoints, coords_to_dict
24
+ from utils.utils_keypoints import KeypointsDB
25
+ from utils.utils_lines import LineKeypointsDB
26
+
27
+
28
+ def parse_args():
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--cfg", type=str, required=True,
31
+ help="Path to the (kp model) configuration file")
32
+ parser.add_argument("--cfg_l", type=str, required=True,
33
+ help="Path to the (line model) configuration file")
34
+ parser.add_argument("--root_dir", type=str, required=True,
35
+ help="Root directory")
36
+ parser.add_argument("--split", type=str, required=True,
37
+ help="Dataset split")
38
+ parser.add_argument("--save_dir", type=str, required=True,
39
+ help="Saving directory")
40
+ parser.add_argument("--weights_kp", type=str, required=True,
41
+ help="Model (keypoints) weights to use")
42
+ parser.add_argument("--weights_line", type=str, required=True,
43
+ help="Model (lines) weights to use")
44
+ parser.add_argument("--cuda", type=str, default="cuda:0",
45
+ help="CUDA device index (default: 'cuda:0')")
46
+ parser.add_argument("--kp_th", type=float, default=0.1)
47
+ parser.add_argument("--line_th", type=float, default=0.1)
48
+ parser.add_argument("--batch", type=int, default=1, help="Batch size")
49
+ parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
50
+ parser.add_argument('--use_gt', action='store_true', help='Use ground truth (default: False)')
51
+
52
+
53
+ args = parser.parse_args()
54
+ return args
55
+
56
+
57
+
58
+ if __name__ == "__main__":
59
+ args = parse_args()
60
+
61
+ files = glob.glob(os.path.join(args.root_dir + args.split, "*.jpg"))
62
+
63
+ if args.use_gt:
64
+ zip_name_pred = args.save_dir + args.split + '_gt.zip'
65
+ else:
66
+ zip_name_pred = args.save_dir + args.split + '_pred.zip'
67
+
68
+
69
+ if args.use_gt:
70
+ device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
71
+
72
+ with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
73
+ for file in tqdm(files, desc="Processing Images"):
74
+ image = Image.open(file)
75
+ w, h = image.size
76
+
77
+ homography_file = args.root_dir + args.split + '/' + \
78
+ file.split('/')[-1].split('.')[0] + '.homographyMatrix'
79
+
80
+ json_path = file.split('.')[0] + ".json"
81
+ with open(json_path) as json_file:
82
+ data = json.load(json_file)
83
+ kp_db = KeypointsDB(data, image)
84
+ line_db = LineKeypointsDB(data, image)
85
+ heatmaps, _ = kp_db.get_tensor_w_mask()
86
+ heatmaps = torch.tensor(heatmaps).unsqueeze(0)
87
+ heatmaps_l = line_db.get_tensor()
88
+ heatmaps_l = torch.tensor(heatmaps_l).unsqueeze(0)
89
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
90
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
91
+ kp_dict = coords_to_dict(kp_coords, threshold=0.1)
92
+ lines_dict = coords_to_dict(line_coords, threshold=0.1)
93
+ final_kp_dict = complete_keypoints(kp_dict, lines_dict, w=w, h=h, normalize=True)
94
+ final_dict = {'kp': final_kp_dict, 'lines': lines_dict}
95
+
96
+ json_data = json.dumps(final_dict)
97
+ zip_file.writestr(f"{file.split('/')[-1].split('.')[0]}.json", json_data)
98
+
99
+
100
+ else:
101
+ device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
102
+ cfg = yaml.safe_load(open(args.cfg, 'r'))
103
+ cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
104
+
105
+ loaded_state = torch.load(args.weights_kp, map_location=device)
106
+ model = get_cls_net(cfg)
107
+ model.load_state_dict(loaded_state)
108
+ model.to(device)
109
+ model.eval()
110
+
111
+ loaded_state_l = torch.load(args.weights_line, map_location=device)
112
+ model_l = get_cls_net_l(cfg_l)
113
+ model_l.load_state_dict(loaded_state_l)
114
+ model_l.to(device)
115
+ model_l.eval()
116
+
117
+ transform = T.Resize((540, 960))
118
+
119
+ with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
120
+ for file in tqdm(files, desc="Processing Images"):
121
+ image = Image.open(file)
122
+ image = f.to_tensor(image).float().to(device).unsqueeze(0)
123
+ image = image if image.size()[-1] == 960 else transform(image)
124
+ b, c, h, w = image.size()
125
+
126
+ homography_file = args.root_dir + args.split + '/' + \
127
+ file.split('/')[-1].split('.')[0] + '.homographyMatrix'
128
+
129
+ with torch.no_grad():
130
+ heatmaps = model(image)
131
+ heatmaps_l = model_l(image)
132
+
133
+ kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
134
+ line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
135
+ kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th, ground_plane_only=True)
136
+ lines_dict = coords_to_dict(line_coords, threshold=args.line_th, ground_plane_only=True)
137
+ final_kp_dict, final_lines_dict = complete_keypoints(kp_dict[0], lines_dict[0],
138
+ w=w, h=h, normalize=True)
139
+ final_dict = {'kp': final_kp_dict, 'lines': final_lines_dict}
140
+
141
+ json_data = json.dumps(final_dict)
142
+ zip_file.writestr(f"{file.split('/')[-1].split('.')[0]}.json", json_data)
143
+
144
+
145
+
146
+
147
+
148
+
scripts/run_pipeline_sn22.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set parameters
4
+ ROOT_DIR="calibration/"
5
+ SPLIT="test"
6
+ CFG="config/hrnetv2_w48.yaml"
7
+ CFG_L="config/hrnetv2_w48_l.yaml"
8
+ WEIGHTS_KP="weights/SV_kp"
9
+ WEIGHTS_L="weights/SV_lines"
10
+ SAVE_DIR="inference/inference_3D/inference_sn22/"
11
+ DEVICE="cuda:0"
12
+ KP_TH=0.1611
13
+ LINE_TH=0.3434
14
+ MAX_REPROJ_ERR=57
15
+ GT_FILE="${SAVE_DIR}${SPLIT}_main.zip"
16
+ PRED_FILE="${SAVE_DIR}${SPLIT}_main_pred.zip"
17
+
18
+
19
+ # Run inference script
20
+ python scripts/inference_sn.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --max_reproj_err $MAX_REPROJ_ERR --cuda $DEVICE --main_cam_only
21
+
22
+ # Run evaluation script
23
+ python sn_calibration/src/evalai_camera.py -s $GT_FILE -p $PRED_FILE
scripts/run_pipeline_sn23.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set parameters
4
+ ROOT_DIR="datasets/calibration-2023/"
5
+ SPLIT="test"
6
+ CFG="config/hrnetv2_w48.yaml"
7
+ CFG_L="config/hrnetv2_w48_l.yaml"
8
+ WEIGHTS_KP="weights/MV_kp"
9
+ WEIGHTS_L="weights/MV_lines"
10
+ SAVE_DIR="inference/inference_3D/inference_sn23/"
11
+ DEVICE="cuda:0"
12
+ KP_TH=0.0712
13
+ LINE_TH=0.2571
14
+ MAX_REPROJ_ERR=38
15
+ GT_FILE="${SAVE_DIR}${SPLIT}.zip"
16
+ PRED_FILE="${SAVE_DIR}${SPLIT}_pred.zip"
17
+
18
+
19
+ # Run inference script
20
+ python scripts/inference_sn.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --max_reproj_err $MAX_REPROJ_ERR --cuda $DEVICE
21
+
22
+ # Run evaluation script
23
+ python sn_calibration/src/evalai_camera.py -s $GT_FILE -p $PRED_FILE
scripts/run_pipeline_tswc.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set parameters
4
+ ROOT_DIR="datasets/TS-WorldCup/"
5
+ SPLIT="test"
6
+ CFG="config/hrnetv2_w48.yaml"
7
+ CFG_L="config/hrnetv2_w48_l.yaml"
8
+ WEIGHTS_KP="weights/SV_FT_TSWC_kp"
9
+ WEIGHTS_L="weights/SV_FT_TSWC_lines"
10
+ SAVE_DIR="inference/inference_2D/inference_tswc/"
11
+ KP_TH=0.2432
12
+ LINE_TH=0.8482
13
+ DEVICE="cuda:0"
14
+ PRED_FILE="${SAVE_DIR}${SPLIT}_pred.zip"
15
+
16
+ # Run inference script
17
+ python scripts/inference_tswc.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --cuda $DEVICE
18
+
19
+ # Run evaluation script
20
+ python scripts/eval_tswc.py --root_dir $ROOT_DIR --split $SPLIT --pred_file $PRED_FILE
scripts/run_pipeline_wc14.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set parameters
4
+ ROOT_DIR="datasets/WC-2014/"
5
+ SPLIT="test"
6
+ CFG="config/hrnetv2_w48.yaml"
7
+ CFG_L="config/hrnetv2_w48_l.yaml"
8
+ WEIGHTS_KP="weights/SV_FT_WC14_kp"
9
+ WEIGHTS_L="weights/SV_FT_WC14_lines"
10
+ SAVE_DIR="inference/inference_2D/inference_wc14/"
11
+ KP_TH=0.1274
12
+ LINE_TH=0.1439
13
+ DEVICE="cuda:0"
14
+ PRED_FILE="${SAVE_DIR}${SPLIT}_pred.zip"
15
+
16
+ # Run inference script
17
+ python scripts/inference_wc14.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --cuda $DEVICE
18
+
19
+ # Run evaluation script
20
+ python scripts/eval_wc14.py --root_dir $ROOT_DIR --split $SPLIT --pred_file $PRED_FILE
scripts/run_pipeline_wc14_3D.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set parameters
4
+ ROOT_DIR="datasets/WC-2014/"
5
+ SPLIT="test"
6
+ CFG="config/hrnetv2_w48.yaml"
7
+ CFG_L="config/hrnetv2_w48_l.yaml"
8
+ WEIGHTS_KP="weights/SV_kp"
9
+ WEIGHTS_L="weights/SV_lines"
10
+ SAVE_DIR="inference/inference_3D/inference_wc14/"
11
+ DEVICE="cuda:0"
12
+ KP_TH=0.0070
13
+ LINE_TH=0.1513
14
+ MAX_REPROJ_ERR=83
15
+ GT_FILE="${SAVE_DIR}${SPLIT}_main.zip"
16
+ PRED_FILE="${SAVE_DIR}${SPLIT}_main_pred.zip"
17
+
18
+
19
+ # Run inference script
20
+ python scripts/inference_sn.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --max_reproj_err $MAX_REPROJ_ERR --cuda $DEVICE --main_cam_only
21
+
22
+ # Run evaluation script
23
+ python sn_calibration/src/evalai_camera.py -s $GT_FILE -p $PRED_FILE
sn_calibration/ChallengeRules.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Guidelines for the Calibration challenge
2
+
3
+ The 2nd [Calibration challenge]() will be held at the
4
+ official [CVSports Workshop](https://vap.aau.dk/cvsports/) at CVPR 2023!
5
+ Subscribe (watch) the repo to receive the latest info regarding timeline and prizes!
6
+
7
+ We provide an [evaluation server](https://eval.ai/web/challenges/challenge-page/1946/overview) for anyone competing in the challenge.
8
+ This evaluation server handles predictions for the open **test** set and the segregated **challenge** set.
9
+
10
+ Winners will be announced at the CVSports Workshop at CVPR 2023.
11
+ Prizes 💲💲💲 include $1000 cash award, sponsored by [EVS Broadcast Equipment](https://evs.com/).
12
+
13
+
14
+ ## Who can participate / How to participate?
15
+
16
+ - Any individual can participate in the challenge, except the organizers.
17
+ - Participants are encouraged to form a team to participate.
18
+ - Each team can have one or more members.
19
+ - An individual/team can compete on both tasks.
20
+ - An individual associated with multiple teams (for a given task) or a team with multiple accounts will be disqualified.
21
+ - On both tasks, a participant may only use the images as input.
22
+ - A participant is allowed to extract their own visual/audio features with any pre-trained model.
23
+
24
+ ## How to win / What is the prize?
25
+
26
+ - For each task, the winner is the individual/team who reaches the highest performance on the **challenge** set.
27
+ - The metric taken into consideration is the combined metric: Completeness * Accuracy@5.
28
+ - The deadline to submit your results is May 30th at 11.59 pm Pacific Time.
29
+ - The teams that perform best will be granted $1000 from our sponsor [EVS Broadcast Equipment](https://evs.com/).
30
+ - In order to be eligible for the prize, we require the individual/team to provide a short report describing the details of the methodology (CVPR format, max 2 pages).
31
+
32
+
33
+
34
+ ## Important dates
35
+
36
+ Note that these dates are tentative and subject to changes if necessary.
37
+
38
+ - **February 1:** Open evaluation server on the (Open) Test set.
39
+ - **February 15:** Open evaluation server on the (Segregated) Challenge set.
40
+ - **May 30:** Close evaluation server.
41
+ - **June 6:** Deadline for submitting the report.
42
+ - **June 19:** A full-day workshop at CVPR 2023.
43
+
44
+ For any further doubt or concern, please raise an issue in this repository, or contact us directly on [Discord](https://discord.gg/SM8uHj9mkP).
sn_calibration/README.md ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ![](./doc/EVS_LOGO_COLOUR_RGB.png)
3
+
4
+ # EVS Camera Calibration Challenge
5
+
6
+
7
+ Welcome to the EVS camera calibration challenge! This challenge is sponsored by EVS Broadcast Equipment, and is
8
+ developed in collaboration with the SoccerNet team.
9
+
10
+ This challenge consists of two distinct tasks which are defined hereafter. We provide sample code and baselines for each
11
+ task to help you get started!
12
+
13
+
14
+ Participate in our upcoming Challenge at the [CVSports](https://vap.aau.dk/cvsports/) workshop at CVPR and try to win up to $1000 sponsored by [EVS](https://evs.com/)! All details are available on the [challenge website](https://eval.ai/web/challenges/challenge-page/1537/overview), or on the [main page](https://www.soccer-net.org/).
15
+
16
+ The participation deadline is fixed at the 30th of May 2023. The official rules and guidelines are available on [ChallengeRules.md](ChallengeRules.md).
17
+
18
+ <a href="https://www.youtube.com/watch?v=nxywN6X_0oE">
19
+ <p align="center"><img src="images/Thumbnail.png" width="720"></p>
20
+ </a>
21
+
22
+ ### 2023 Leaderboard
23
+
24
+ <p><table class="dataframe">
25
+ <thead>
26
+ <tr style="text-align: right;">
27
+ <th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Team</th>
28
+ <th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Combined Metric</th>
29
+ <th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Accuracy@5</th>
30
+ <th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Completeness</th>
31
+ </tr>
32
+ </thead>
33
+ <tbody>
34
+ <tr>
35
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Sportlight</td>
36
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.55</td>
37
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">73.22</td>
38
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">75.59</td>
39
+ </tr>
40
+ <tr>
41
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Spiideo</td>
42
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.53</td>
43
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">52.95</td>
44
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">99.96</td>
45
+ </tr>
46
+ <tr>
47
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">SAIVA_Calibration</td>
48
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.53</td>
49
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">60.33</td>
50
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">87.22</td>
51
+ </tr>
52
+ <tr>
53
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">BPP</td>
54
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.5</td>
55
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">69.12</td>
56
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">72.54</td>
57
+ </tr>
58
+ <tr>
59
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">ikapetan</td>
60
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.43</td>
61
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">53.78</td>
62
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">79.71</td>
63
+ </tr>
64
+ <tr>
65
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">NASK</td>
66
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.41</td>
67
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">53.01</td>
68
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">77.81</td>
69
+ </tr>
70
+ <tr>
71
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Mike Azatov and Jonas Theiner</td>
72
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.41</td>
73
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">58.61</td>
74
+ <td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">69.34</td>
75
+ </tr>
76
+ <tr>
77
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Baseline</td>
78
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.08</td>
79
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">13.54</td>
80
+ <td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">61.54</td>
81
+ </tr>
82
+ </tbody>
83
+ </table></p>
84
+
85
+ ## Table of content
86
+
87
+ - Install
88
+ - Dataset
89
+ - Soccer pitch annotations
90
+ - Camera calibration
91
+ - Definition
92
+ - Evaluation
93
+ - Baseline
94
+
95
+ ## Install
96
+
97
+ Download the network weights from google drive https://drive.google.com/file/d/1dbN7LdMV03BR1Eda8n7iKNIyYp9r07sM/view?usp=sharing
98
+ and place them in [resources](resources).
99
+
100
+ With python 3 already installed, the python environment can be installed with the following command:
101
+
102
+ ```
103
+ pip install -r requirements.txt
104
+ ```
105
+
106
+ ## Dataset
107
+
108
+ SoccerNet is a dataset containing 400 broadcasted videos of whole soccer games. The dataset can be found
109
+ here : https://soccer-net.org/
110
+ The authors give access to template code and a python package to get started with the dataset. All the documentation
111
+ about these tools can be found here : https://github.com/SilvioGiancola/SoccerNetv2-DevKit
112
+
113
+ All the data needed for the challenge can be downloaded with these lines:
114
+
115
+ ```python
116
+ from SoccerNet.Downloader import SoccerNetDownloader as SNdl
117
+ soccerNetDownloader = SNdl(LocalDirectory="path/to/SoccerNet")
118
+ soccerNetDownloader.downloadDataTask(task="calibration-2023", split=["train","valid","test","challenge"])
119
+ ```
120
+
121
+ Historically, the dataset was first released for an action spotting task. In its first version, the images corresponding
122
+ to soccer actions (goals, fouls, etc.) were identified. In the following editions, more annotations have been associated
123
+ with those images. In the latest version of the dataset (SoccerNetV3), the extremities of the soccer pitch line
124
+ markings have been annotated. In partnership with the SoccerNet team, we use these annotations in a new challenge. The
125
+ challenge is divided into two tasks, with the first feeding into the second: the first is a soccer pitch
126
+ element localisation task, whose output can then be used for the second, a camera calibration task.
127
+
128
+ **/!\ New**: some annotations have been added: for some images, there are new point annotations along the pitch marking lines.
129
+ For straight pitch marking lines, you can always assume that the extremities are annotated, and sometimes,
130
+ if the image has been reannotated, there will be a few extra points along the imaged line.
131
+
132
+
133
+ ### Soccer pitch annotations
134
+
135
+ Performing camera calibration can be eased by the presence of an object with a known shape in the image. For soccer
136
+ content, the soccer pitch can be used as a target for the camera calibration because it has a known shape and its
137
+ dimensions are specified in the International Football Association Board's law of the
138
+ game (https://digitalhub.fifa.com/m/5371a6dcc42fbb44/original/d6g1medsi8jrrd3e4imp-pdf.pdf).
139
+
140
+ Moreover, we define a set of semantic labels for each semantic element of the soccer pitch. We also define the bottom
141
+ side of the pitch as the one where the main and 16 meters broadcast cameras are installed.
142
+
143
+ ![](./doc/soccernet_classes.png)
144
+
145
+ 1. Big rect. left bottom,
146
+ 2. Big rect. left main,
147
+ 3. Big rect. left top,
148
+ 4. Big rect. right bottom,
149
+ 5. Big rect. right main,
150
+ 6. Big rect. right top,
151
+ 7. Circle central,
152
+ 8. Circle left,
153
+ 9. Circle right,
154
+ 10. Goal left crossbar,
155
+ 11. Goal left post left,
156
+ 12. Goal left post right,
157
+ 13. Goal right crossbar,
158
+ 14. Goal right post left,
159
+ 15. Goal right post right,
160
+ 16. Middle line,
161
+ 17. Side line bottom,
162
+ 18. Side line left,
163
+ 19. Side line right,
164
+ 20. Side line top,
165
+ 21. Small rect. left bottom,
166
+ 22. Small rect. left main,
167
+ 23. Small rect. left top,
168
+ 24. Small rect. right bottom,
169
+ 25. Small rect. right main,
170
+ 26. Small rect. right top
171
+
172
+ In the third version of SoccerNet, there are new annotations for each image of the dataset. These new annotations
173
+ consist of the list of all extremities of the semantic elements of the pitch present in the image. The extremities are
174
+ a pair of 2D point coordinates.
175
+
176
+ For the circles drawn on the pitch, the annotations consist of a list of points which roughly give the shape of the
177
+ circle when connected. Note that due to new annotations, the sequential order of circle points is no longer guaranteed.
178
+
179
+
180
+
181
+ ## Camera calibration task
182
+
183
+ As a soccer pitch has known dimensions, it is possible to use the soccer pitch as a calibration target in order to
184
+ calibrate the camera. Since the pitch is roughly planar, we can model the transformation applied by the camera to the pitch by a
185
+ homography. In order to estimate the homography between the pitch and its image, we only need 4 lines. We provide a
186
+ baseline in order to extract camera parameters from this kind of image.
187
+
188
+ ### Definition
189
+
190
+ In this task, we ask you to provide valid camera parameters for each image of the challenge set.
191
+
192
+ Given a common 3D pitch template, we will use the camera parameters produced by your algorithm in order to estimate the
193
+ reprojection error induced by the camera parameters. The camera parameters include the lens parameters, the orientation,
194
+ and the translation of the camera with respect to the world reference axis system, which we define as follows:
195
+
196
+ ![](./doc/axis-system-resized.jpeg)
197
+
198
+ #### Rotation convention
199
+
200
+ Following Euler angles convention, we use a ZXZ succession of intrinsic rotations in order to describe the orientation
201
+ of the camera. Starting from the world reference axis system, we first apply a rotation around the Z axis to pan the
202
+ camera. Then the obtained axis system is rotated around its x axis in order to tilt the camera. Then the last rotation
203
+ around the z axis of the new axis system rolls the camera. Note that this z axis is the principal axis of the
204
+ camera.
205
+
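+ As an illustration, here is a minimal numpy sketch of this ZXZ convention (it mirrors the `pan_tilt_roll_to_orientation` helper shipped in `src/camera.py`; angles are assumed to be given in radians):
+
+ ```python
+ import numpy as np
+
+ def pan_tilt_roll_to_rotation(pan, tilt, roll):
+     """Orientation obtained from intrinsic Z (pan) -> X (tilt) -> Z (roll) rotations."""
+     Rz_pan = np.array([[np.cos(pan), -np.sin(pan), 0],
+                        [np.sin(pan),  np.cos(pan), 0],
+                        [0, 0, 1]])
+     Rx_tilt = np.array([[1, 0, 0],
+                         [0, np.cos(tilt), -np.sin(tilt)],
+                         [0, np.sin(tilt),  np.cos(tilt)]])
+     Rz_roll = np.array([[np.cos(roll), -np.sin(roll), 0],
+                         [np.sin(roll),  np.cos(roll), 0],
+                         [0, 0, 1]])
+     return Rz_pan @ Rx_tilt @ Rz_roll
+ ```
+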
206
+ The lens parameters produced must follow the pinhole model. Additionally, the parameters can include radial, tangential
207
+ and thin prism distortion parameters. This corresponds to the full model of
208
+ OpenCV : https://docs.opencv.org/4.5.0/d9/d0c/group__calib3d.html#details
209
+
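+ To make the parameterization concrete, here is a minimal sketch (with hypothetical values) that projects 3D pitch points using OpenCV's full distortion model, whose coefficient vector is ordered as (k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4):
+
+ ```python
+ import numpy as np
+ import cv2 as cv
+
+ # Hypothetical values; in practice they come from a camera_{frame_index}.json file.
+ K = np.array([[3921.6, 0.0, 480.0],
+               [0.0, 3921.6, 270.0],
+               [0.0, 0.0, 1.0]])            # pinhole calibration matrix
+ rvec = np.zeros(3)                          # world -> camera rotation (identity here)
+ tvec = np.array([0.0, 0.0, 80.0])           # note: tvec = -R @ position_meters
+ dist = np.zeros(12)                         # all distortion coefficients set to zero here
+ pitch_points = np.array([[0.0, 0.0, 0.0],
+                          [52.5, 34.0, 0.0]])  # points on the pitch plane (Z = 0), in meters
+ image_points, _ = cv.projectPoints(pitch_points, rvec, tvec, K, dist)
+ ```
+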
210
+ For each image of the test set, we expect to receive a json file named "**camera_{frame_index}.json**" containing a
211
+ dictionary with the camera parameters.
212
+
213
+ ```
214
+ # camera_00001.json
215
+ {
216
+ "pan_degrees": 14.862476218376278,
217
+ "tilt_degrees": 78.83988009048775,
218
+ "roll_degrees": -2.2210919345134497,
219
+ "position_meters": [
220
+ 32.6100008989567,
221
+ 67.9363036953344,
222
+ -14.898134157887508
223
+ ],
224
+ "x_focal_length": 3921.6013418112757,
225
+ "y_focal_length": 3921.601341812138,
226
+ "principal_point": [
227
+ 480.0,
228
+ 270.0
229
+ ],
230
+ "radial_distortion": [
231
+ 0.0,
232
+ 0.0,
233
+ 0.0,
234
+ 0.0,
235
+ 0.0,
236
+ 0.0
237
+ ],
238
+ "tangential_distortion": [
239
+ 0.0,
240
+ 0.0
241
+ ],
242
+ "thin_prism_distortion": [
243
+ 0.0,
244
+ 0.0,
245
+ 0.0,
246
+ 0.0
247
+ ]
248
+ }
249
+ ```
250
+
251
+ The results should be organized as follows :
252
+
253
+
254
+ ```
255
+ test.zip
256
+ |__ camera_00001.json
257
+ |__ camera_00002.json
258
+ ```
259
+
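+ For convenience, a minimal packaging sketch (assuming a dictionary `cameras` that maps an integer frame index to a parameter dictionary following the format above):
+
+ ```python
+ import json
+ import zipfile
+
+ def write_submission(cameras, zip_path="test.zip"):
+     """Write one camera_{frame_index}.json entry per frame into the submission archive."""
+     with zipfile.ZipFile(zip_path, "w") as zf:
+         for frame_index, params in cameras.items():
+             zf.writestr(f"camera_{int(frame_index):05d}.json", json.dumps(params, indent=4))
+ ```
+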
260
+
261
+ ### Evaluation
262
+
263
+ We only evaluate the quality of the provided camera parameters in the image domain, as it is the only groundtruth that we
264
+ have. The annotated points are either extremities or points on circles, and thus they do not always have a mapping to a
265
+ 3D point, but we can always map them to lines or circles.
266
+
267
+ #### Dealing with ambiguities
268
+
269
+ The dataset contains some ambiguities when we consider each image independently. Without context, one can't know if a
270
+ camera behind a goal is on the left or the right side of the pitch.
271
+
272
+ For example, consider this image taken by a camera behind the goal:
273
+
274
+ ![](./doc/tactical_amibguity_resized.jpeg)
275
+
276
+ It is impossible to say without any context if it was shot by the camera whose filming range is in blue or in yellow.
277
+
278
+ ![](./doc/ambiguity_soccernet_resized.jpeg)
279
+
280
+ Therefore, we take this ambiguity into account in the evaluation: we compute both the accuracy for the groundtruth
281
+ labels and the accuracy for their central projection through the pitch center. The higher accuracy will be
282
+ selected for your evaluation.
283
+
284
+
285
+ We evaluate the best submission based on the accuracy at a specific threshold distance. This metric is explained
286
+ hereunder.
287
+
288
+ #### Accuracy @ threshold
289
+
290
+ The evaluation is based on the reprojection error which we define here as the L2 distance between one annotated point
291
+ and the line to which the point belongs. This metric does not account well for false positives and false negatives
292
+ (hallucinated/missing line projections). Thus we formulate our evaluation as a binary classification, with a
293
+ distance threshold and a twist: this time, we consider a pitch marking to be one entity, and for it to be correctly
294
+ detected, all its extremities (or all points annotated for circles) must have a reprojection error smaller than the
295
+ threshold.
296
+
297
+ As we allow lens distortion, the projection of the pitch line markings can be curved. This is why we sample the pitch
298
+ model every few centimeters, and we consider that the distance between a projected pitch marking and a groundtruth point is in
299
+ fact the Euclidean distance between the groundtruth point and the polyline given by the projection of sampled points.
300
+
301
+ * True positives : for classes that belong both to the prediction and the groundtruth, a predicted element is a True
302
+ Positive if all the L2 distances between its groundtruth points and the predicted polyline are lower than a
303
+ certain threshold.
304
+
305
+ ![](./doc/tp-condition-2.png)
306
+
307
+ * False positives : elements detected with a class that does not belong to the groundtruth classes, and
308
+ elements with valid classes that are at least **t** pixels away from one of the groundtruth points associated
309
+ with the element.
310
+ * False negatives: Line elements only present in the groundtruth are counted as False Negatives.
311
+ * True negatives : There are no True Negatives.
312
+
313
+ The Accuracy for a threshold of t pixels is given by : **Acc@t = TP/(TP+FN+FP)**. We evaluate the accuracy at 5 pixels.
314
+ We only use images with predicted camera parameters in this evaluation.
315
+
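+ The following is a minimal sketch (not the official evaluation code) of the point-to-polyline distance on which this criterion relies, together with the accuracy formula:
+
+ ```python
+ import numpy as np
+
+ def point_to_polyline_distance(point, polyline):
+     """Smallest L2 distance from a 2D point to a polyline given as an (N, 2) array of vertices."""
+     point = np.asarray(point, dtype=float)
+     polyline = np.asarray(polyline, dtype=float)
+     if len(polyline) == 1:
+         return float(np.linalg.norm(point - polyline[0]))
+     best = np.inf
+     for a, b in zip(polyline[:-1], polyline[1:]):
+         ab = b - a
+         denom = float(np.dot(ab, ab))
+         t = 0.0 if denom == 0.0 else float(np.clip(np.dot(point - a, ab) / denom, 0.0, 1.0))
+         best = min(best, float(np.linalg.norm(point - (a + t * ab))))
+     return best
+
+ def accuracy_at_t(tp, fp, fn):
+     """Acc@t = TP / (TP + FN + FP)."""
+     total = tp + fn + fp
+     return tp / total if total else 0.0
+ ```
+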
316
+ #### Completeness rate
317
+
318
+ We also measure the completeness rate as the number of camera parameters provided divided by the number of images with
319
+ more than four semantic line annotations in the dataset.
320
+
321
+ #### Final score
322
+
323
+ The evaluation criterion for a camera calibration method is the following : **Completeness x Acc@5**
324
+
325
+ #### Per class information
326
+
327
+ The global accuracy described above has the advantage of treating all soccer pitch marking types equivalently even if the
328
+ groundtruth contains more annotated points for a specific class. For instance, a circle is annotated with 9 points on
329
+ average, whilst rectilinear elements usually have two annotated points. But this metric might be harsher, as it considers
330
+ all annotated points of an element jointly instead of each of them independently. This is precisely why we propose this per-class metric, which
331
+ accounts for each annotated point separately. This metric is only for information purposes and will not be used in the
332
+ ranking of submissions.
333
+
334
+ The prediction is obtained by sampling each 3D real world pitch element every few centimeters, which means that the
335
+ number of points in the prediction may be variable for the same camera parameters. This gives a very high number of
336
+ predicted points for a certain class, and thus we find a workaround to count false positives.
337
+
338
+ The confusion matrices are computed per class in the following way :
339
+
340
+ * True positives : for classes that belong both to the prediction and the groundtruth, a groundtruth point is counted in
341
+ the True Positives if the L2 distance from this groundtruth point to the predicted polyline is lower than a
342
+ certain threshold **t**.
343
+
344
+ * False positives : counts groundtruth points that have a distance to the corresponding predicted polyline that is higher
345
+ than the threshold value **t**. For predicted lines that do not belong to the groundtruth, we cannot count every
346
+ predicted point as a false positive, because the number of points depends on the sampling factor, which has to be high
347
+ enough and could therefore lower our metric a lot. We arbitrarily decided to count, for this kind of false positive, the number
348
+ of points that a human annotator would have annotated for this class, i.e. two points for rectilinear elements and 9
349
+ for circle elements.
350
+ * False negatives: All groundtruth points whose class is only present in the groundtruth are counted as False Negatives.
351
+ * True negatives : There are no True Negatives.
352
+
353
+ The Accuracy for a threshold of t pixels is given by : **Acc@t = TP/(TP+FN+FP)**. We evaluate the accuracy at 5, 10 and
354
+ 20 pixels. We only use images with predicted camera parameters in this evaluation.
355
+
356
+
357
+ ### Baseline
358
+
359
+ #### Method
360
+
361
+ For our camera calibration baseline, we proceed in two steps: first we find the pitch markings location in the image,
362
+ and then given our pitch marking correspondences, we estimate camera parameters.
363
+
364
+ For this first step, we decided to locate the pitch markings with a neural network trained to perform semantic line
365
+ segmentation. We used DeepLabv3 architecture. The target semantic segmentation masks were generated by joining
366
+ successively all the points annotated for each line in the image. We provide the dataloader that we used for the
367
+ training in the src folder.
368
+
369
+ The segmentation maps predicted by the neural
370
+ network are further processed in order to get the line extremities. First, for each class, we fit circles on the
371
+ segmentation mask. All pixels belonging to the same class are thus summarized by a set of points (i.e. circle centers).
372
+ Then we build polylines based on the circle centers: all the points that are close enough are considered to belong to
373
+ the same polyline. Finally, the extremities of each line class will be the extremities of the longest polyline for that
374
+ line class.
375
+
376
+ You can test the line detection with the following code:
377
+
378
+ ```
379
+ python src/detect_extremities.py -s <path_to_soccernet_dataset> -p <path_to_store_predictions>
380
+ ```
381
+
382
+
383
+ In the second step, we use the extremities of the lines (
384
+ not circles) detected in the image in order to estimate a homography from the soccer pitch model to the image. We
385
+ provide a module **soccerpitch.py** that defines the pitch model according to the rules of the game. The homography is then
386
+ decomposed into camera parameters. All aspects concerning camera parameters are handled in **camera.py**, including the
387
+ decomposition of the homography into rotation and translation matrices, the estimation of the calibration matrix from the
388
+ homography, and projection functions.
389
+
390
+ You can test the baseline with the following line :
391
+
392
+ ```python src/baseline_cameras.py -s <path to soccernet dataset> -p <path to 1st task prediction>```
393
+
394
+ And to test the evaluation, you can run :
395
+
396
+ `python src/evaluate_camera.py -s <path to soccernet dataset> -p <path to predictions> -t <threshold value>`
397
+
398
+
399
+ #### Results
400
+
401
+ | Method | Acc@5 | Completeness | Final score |
402
+ |----------|-------|--------------|-------------|
403
+ | Baseline | 11.7% | 68% | 7.96% |
404
+
405
+ #### Improvements
406
+
407
+ The baseline could be directly improved by :
408
+
409
+ * exploiting the masks rather than the extremity predictions of the first baseline
410
+ * using RANSAC to estimate the homography (see the sketch below)
411
+ * refining the camera parameters using line and ellipse correspondences.
412
+
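+ As an illustration of the RANSAC suggestion above, a minimal sketch using OpenCV (assuming matching (N, 2) arrays of pitch-model points, in meters, and image points, in pixels, with N >= 4):
+
+ ```python
+ import numpy as np
+ import cv2 as cv
+
+ def fit_homography_ransac(model_points, image_points, reproj_threshold=3.0):
+     """Robustly estimate the pitch-to-image homography from 2D point correspondences."""
+     H, inlier_mask = cv.findHomography(np.asarray(model_points, dtype=np.float32),
+                                        np.asarray(image_points, dtype=np.float32),
+                                        cv.RANSAC, reproj_threshold)
+     return H, inlier_mask
+ ```
+
+ The resulting homography could then be passed to `Camera.from_homography` (provided in `src/camera.py`) to recover rough camera parameters.
+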
413
+ ## Citation
414
+ For further information check out the paper and supplementary material:
415
+ https://arxiv.org/abs/2210.02365
416
+
417
+ Please cite our work if you use the SoccerNet dataset:
418
+ ```bibtex
419
+ @inproceedings{Giancola_2022,
420
+ doi = {10.1145/3552437.3558545},
421
+ url = {https://doi.org/10.1145%2F3552437.3558545},
422
+ year = 2022,
423
+ month = {oct},
424
+ publisher = {{ACM}},
425
+ author = {Silvio Giancola and Anthony Cioppa and Adrien Deli{\`{e}}ge and Floriane Magera and Vladimir Somers and Le Kang and Xin Zhou and Olivier Barnich and Christophe De Vleeschouwer and Alexandre Alahi and Bernard Ghanem and Marc Van Droogenbroeck and Abdulrahman Darwish and Adrien Maglo and Albert Clap{\'{e}}s and Andreas Luyts and Andrei Boiarov and Artur Xarles and Astrid Orcesi and Avijit Shah and Baoyu Fan and Bharath Comandur and Chen Chen and Chen Zhang and Chen Zhao and Chengzhi Lin and Cheuk-Yiu Chan and Chun Chuen Hui and Dengjie Li and Fan Yang and Fan Liang and Fang Da and Feng Yan and Fufu Yu and Guanshuo Wang and H. Anthony Chan and He Zhu and Hongwei Kan and Jiaming Chu and Jianming Hu and Jianyang Gu and Jin Chen and Jo{\~{a}}o V. B. Soares and Jonas Theiner and Jorge De Corte and Jos{\'{e}} Henrique Brito and Jun Zhang and Junjie Li and Junwei Liang and Leqi Shen and Lin Ma and Lingchi Chen and Miguel Santos Marques and Mike Azatov and Nikita Kasatkin and Ning Wang and Qiong Jia and Quoc Cuong Pham and Ralph Ewerth and Ran Song and Rengang Li and Rikke Gade and Ruben Debien and Runze Zhang and Sangrok Lee and Sergio Escalera and Shan Jiang and Shigeyuki Odashima and Shimin Chen and Shoichi Masui and Shouhong Ding and Sin-wai Chan and Siyu Chen and Tallal El-Shabrawy and Tao He and Thomas B. Moeslund and Wan-Chi Siu and Wei Zhang and Wei Li and Xiangwei Wang and Xiao Tan and Xiaochuan Li and Xiaolin Wei and Xiaoqing Ye and Xing Liu and Xinying Wang and Yandong Guo and Yaqian Zhao and Yi Yu and Yingying Li and Yue He and Yujie Zhong and Zhenhua Guo and Zhiheng Li},
426
+ title = {{SoccerNet} 2022 Challenges Results},
427
+ booktitle = {Proceedings of the 5th International {ACM} Workshop on Multimedia Content Analysis in Sports}
428
+ }
429
+ ```
430
+
431
+
432
+
433
+ ## Our other Challenges
434
+
435
+ Check out our other challenges related to SoccerNet!
436
+ - [Action Spotting](https://github.com/SoccerNet/sn-spotting)
437
+ - [Replay Grounding](https://github.com/SoccerNet/sn-grounding)
438
+ - [Calibration](https://github.com/SoccerNet/sn-calibration)
439
+ - [Re-Identification](https://github.com/SoccerNet/sn-reid)
440
+ - [Tracking](https://github.com/SoccerNet/sn-tracking)
sn_calibration/requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # numpy~>1.21
2
+ # torch~=1.10.0
3
+ # torchvision~=0.11.1
4
+ # pillow~>9.0.0
5
+ # tqdm~=4.62.3
6
+ # SoccerNet>=0.1.23
7
+ # opencv-python~=4.5.5
8
+ # matplotlib~=3.5.1
9
+
10
+ numpy>=1.21
11
+ torch>=2.0.0
12
+ torchvision>=0.15.0
13
+ pillow>=9.0.0
14
+ tqdm~=4.62.3
15
+ SoccerNet>=0.1.23
16
+ opencv-python~=4.5.5
17
+ matplotlib~=3.5.1
sn_calibration/resources/mean.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dcd1f96b2486853498b6e38a0eb344cda9c272cd0fe87adb84ec0a09e244a36
3
+ size 152
sn_calibration/resources/std.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49dac1cfa5ab2f197a64332c4fd7dc08335dd46d31244acc0e4a03686241fe79
3
+ size 152
sn_calibration/src/baseline_cameras.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import cv2 as cv
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from src.camera import Camera
10
+ from src.soccerpitch import SoccerPitch
11
+
12
+
13
+ def normalization_transform(points):
14
+ """
15
+ Computes the similarity transform such that the list of points is centered around (0,0) and that its average distance to the
16
+ center is sqrt(2).
17
+ :param points: point cloud that we wish to normalize
18
+ :return: the affine transformation matrix
19
+ """
20
+ center = np.mean(points, axis=0)
21
+
22
+ d = 0.
23
+ nelems = 0
24
+ for p in points:
25
+ nelems += 1
26
+ x = p[0] - center[0]
27
+ y = p[1] - center[1]
28
+ di = np.sqrt(x ** 2 + y ** 2)
29
+ d += (di - d) / nelems
30
+
31
+ if d <= 0.:
32
+ s = 1.
33
+ else:
34
+ s = np.sqrt(2) / d
35
+ T = np.zeros((3, 3))
36
+ T[0, 0] = s
37
+ T[0, 2] = -s * center[0]
38
+ T[1, 1] = s
39
+ T[1, 2] = -s * center[1]
40
+ T[2, 2] = 1
41
+ return T
42
+
43
+
44
+ def estimate_homography_from_line_correspondences(lines, T1=np.eye(3), T2=np.eye(3)):
45
+ """
46
+ Given line correspondences, computes the homography that best maps the two sets of lines.
47
+ :param lines: list of pair of 2D lines matches.
48
+ :param T1: Similarity transform to normalize the elements of the source reference system
49
+ :param T2: Similarity transform to normalize the elements of the target reference system
50
+ :return: boolean to indicate success or failure of the estimation, homography
51
+ """
52
+ homography = np.eye(3)
53
+ A = np.zeros((len(lines) * 2, 9))
54
+
55
+ for i, line_pair in enumerate(lines):
56
+ src_line = np.transpose(np.linalg.inv(T1)) @ line_pair[0]
57
+ target_line = np.transpose(np.linalg.inv(T2)) @ line_pair[1]
58
+ u = src_line[0]
59
+ v = src_line[1]
60
+ w = src_line[2]
61
+
62
+ x = target_line[0]
63
+ y = target_line[1]
64
+ z = target_line[2]
65
+
66
+ A[2 * i, 0] = 0
67
+ A[2 * i, 1] = x * w
68
+ A[2 * i, 2] = -x * v
69
+ A[2 * i, 3] = 0
70
+ A[2 * i, 4] = y * w
71
+ A[2 * i, 5] = -v * y
72
+ A[2 * i, 6] = 0
73
+ A[2 * i, 7] = z * w
74
+ A[2 * i, 8] = -v * z
75
+
76
+ A[2 * i + 1, 0] = x * w
77
+ A[2 * i + 1, 1] = 0
78
+ A[2 * i + 1, 2] = -x * u
79
+ A[2 * i + 1, 3] = y * w
80
+ A[2 * i + 1, 4] = 0
81
+ A[2 * i + 1, 5] = -u * y
82
+ A[2 * i + 1, 6] = z * w
83
+ A[2 * i + 1, 7] = 0
84
+ A[2 * i + 1, 8] = -u * z
85
+
86
+ try:
87
+ u, s, vh = np.linalg.svd(A)
88
+ except np.linalg.LinAlgError:
89
+ return False, homography
90
+ v = np.eye(3)
91
+ has_positive_singular_value = False
92
+ for i in range(s.shape[0] - 1, -2, -1):
93
+ v = np.reshape(vh[i], (3, 3))
94
+
95
+ if s[i] > 0:
96
+ has_positive_singular_value = True
97
+ break
98
+
99
+ if not has_positive_singular_value:
100
+ return False, homography
101
+
102
+ homography = np.reshape(v, (3, 3))
103
+ homography = np.linalg.inv(T2) @ homography @ T1
104
+ homography /= homography[2, 2]
105
+
106
+ return True, homography
107
+
108
+
109
+ def draw_pitch_homography(image, homography):
110
+ """
111
+ Draws points along the soccer pitch markings elements in the image based on the homography projection.
112
+ /!\ This function assumes that the resolution of the image is 540p.
113
+ :param image
114
+ :param homography: homography that captures the relation between the world pitch plane and the image
115
+ :return: modified image
116
+ """
117
+ field = SoccerPitch()
118
+ polylines = field.sample_field_points()
119
+ for line in polylines.values():
120
+
121
+ for point in line:
122
+ if point[2] == 0.:
123
+ hp = np.array((point[0], point[1], 1.))
124
+ projected = homography @ hp
125
+ if projected[2] == 0.:
126
+ continue
127
+ projected /= projected[2]
128
+ if 0 < projected[0] < 960 and 0 < projected[1] < 540:
129
+ cv.circle(image, (int(projected[0]), int(projected[1])), 1, (255, 0, 0), 1)
130
+
131
+ return image
132
+
133
+
134
+ if __name__ == "__main__":
135
+
136
+ parser = argparse.ArgumentParser(description='Baseline for camera parameters extraction')
137
+
138
+ parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023-bis/", type=str,
139
+ help='Path to the SoccerNet-V3 dataset folder')
140
+ parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/",
141
+ required=False, type=str,
142
+ help="Path to the prediction folder")
143
+ parser.add_argument('--split', required=False, type=str, default="valid", help='Select the split of data')
144
+ parser.add_argument('--resolution_width', required=False, type=int, default=960,
145
+ help='width resolution of the images')
146
+ parser.add_argument('--resolution_height', required=False, type=int, default=540,
147
+ help='height resolution of the images')
148
+ args = parser.parse_args()
149
+
150
+ field = SoccerPitch()
151
+
152
+ dataset_dir = os.path.join(args.soccernet, args.split)
153
+ if not os.path.exists(dataset_dir):
154
+ print("Invalid dataset path !")
155
+ exit(-1)
156
+
157
+ with open(os.path.join(dataset_dir, "per_match_info.json"), 'r') as f:
158
+ match_info = json.load(f)
159
+
160
+ with tqdm(enumerate(match_info.keys()), total=len(match_info.keys()), ncols=160) as t:
161
+ for i, match in t:
162
+ frame_list = match_info[match].keys()
163
+
164
+ for frame in frame_list:
165
+ frame_index = frame.split(".")[0]
166
+ prediction_file = os.path.join(args.prediction, args.split, f"extremities_{frame_index}.json")
167
+
168
+ if not os.path.exists(prediction_file):
169
+ continue
170
+
171
+ with open(prediction_file, 'r') as f:
172
+ predictions = json.load(f)
173
+
174
+ camera_predictions = dict()
175
+ image_path = os.path.join(dataset_dir, frame)
176
+ # cv_image = cv.imread(image_path)
177
+ # cv_image = cv.resize(cv_image, (args.resolution_width, args.resolution_height))
178
+
179
+ line_matches = []
180
+ potential_3d_2d_matches = {}
181
+ src_pts = []
182
+ success = False
183
+ for k, v in predictions.items():
184
+ if k == 'Circle central' or "unknown" in k:
185
+ continue
186
+ P3D1 = field.line_extremities_keys[k][0]
187
+ P3D2 = field.line_extremities_keys[k][1]
188
+ p1 = np.array([v[0]['x'] * args.resolution_width, v[0]['y'] * args.resolution_height, 1.])
189
+ p2 = np.array([v[1]['x'] * args.resolution_width, v[1]['y'] * args.resolution_height, 1.])
190
+ src_pts.extend([p1, p2])
191
+ if P3D1 in potential_3d_2d_matches.keys():
192
+ potential_3d_2d_matches[P3D1].extend([p1, p2])
193
+ else:
194
+ potential_3d_2d_matches[P3D1] = [p1, p2]
195
+ if P3D2 in potential_3d_2d_matches.keys():
196
+ potential_3d_2d_matches[P3D2].extend([p1, p2])
197
+ else:
198
+ potential_3d_2d_matches[P3D2] = [p1, p2]
199
+
200
+ start = (int(p1[0]), int(p1[1]))
201
+ end = (int(p2[0]), int(p2[1]))
202
+ # cv.line(cv_image, start, end, (0, 0, 255), 1)
203
+
204
+ line = np.cross(p1, p2)
205
+ if np.isnan(np.sum(line)) or np.isinf(np.sum(line)):
206
+ continue
207
+ line_pitch = field.get_2d_homogeneous_line(k)
208
+ if line_pitch is not None:
209
+ line_matches.append((line_pitch, line))
210
+
211
+ if len(line_matches) >= 4:
212
+ target_pts = [field.point_dict[k][:2] for k in potential_3d_2d_matches.keys()]
213
+ T1 = normalization_transform(target_pts)
214
+ T2 = normalization_transform(src_pts)
215
+ success, homography = estimate_homography_from_line_correspondences(line_matches, T1, T2)
216
+ if success:
217
+ # cv_image = draw_pitch_homography(cv_image, homography)
218
+
219
+ cam = Camera(args.resolution_width, args.resolution_height)
220
+ success = cam.from_homography(homography)
221
+ if success:
222
+ point_matches = []
223
+ added_pts = set()
224
+ for k, potential_matches in potential_3d_2d_matches.items():
225
+ p3D = field.point_dict[k]
226
+ projected = cam.project_point(p3D)
227
+
228
+ if 0 < projected[0] < args.resolution_width and 0 < projected[
229
+ 1] < args.resolution_height:
230
+ dist = np.zeros(len(potential_matches))
231
+ for i, potential_match in enumerate(potential_matches):
232
+ dist[i] = np.sqrt((projected[0] - potential_match[0]) ** 2 + (
233
+ projected[1] - potential_match[1]) ** 2)
234
+ selected = np.argmin(dist)
235
+ if dist[selected] < 100:
236
+ point_matches.append((p3D, potential_matches[selected][:2]))
237
+
238
+ if len(point_matches) > 3:
239
+ cam.refine_camera(point_matches)
240
+ # cam.draw_colorful_pitch(cv_image, SoccerField.palette)
241
+ # print(image_path)
242
+ # cv.imshow("colorful pitch", cv_image)
243
+ # cv.waitKey(0)
244
+
245
+ if success:
246
+ camera_predictions = cam.to_json_parameters()
247
+
248
+ task2_prediction_file = os.path.join(args.prediction, args.split, f"camera_{frame_index}.json")
249
+ if camera_predictions:
250
+ with open(task2_prediction_file, "w") as f:
251
+ json.dump(camera_predictions, f, indent=4)
sn_calibration/src/camera.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2 as cv
2
+ import numpy as np
3
+
4
+ from soccerpitch import SoccerPitch
5
+
6
+
7
+ def pan_tilt_roll_to_orientation(pan, tilt, roll):
8
+ """
9
+ Conversion from euler angles to orientation matrix.
10
+ :param pan:
11
+ :param tilt:
12
+ :param roll:
13
+ :return: orientation matrix
14
+ """
15
+ Rpan = np.array([
16
+ [np.cos(pan), -np.sin(pan), 0],
17
+ [np.sin(pan), np.cos(pan), 0],
18
+ [0, 0, 1]])
19
+ Rroll = np.array([
20
+ [np.cos(roll), -np.sin(roll), 0],
21
+ [np.sin(roll), np.cos(roll), 0],
22
+ [0, 0, 1]])
23
+ Rtilt = np.array([
24
+ [1, 0, 0],
25
+ [0, np.cos(tilt), -np.sin(tilt)],
26
+ [0, np.sin(tilt), np.cos(tilt)]])
27
+ rotMat = np.dot(Rpan, np.dot(Rtilt, Rroll))
28
+ return rotMat
29
+
30
+
31
+ def rotation_matrix_to_pan_tilt_roll(rotation):
32
+ """
33
+ Decomposes the rotation matrix into pan, tilt and roll angles. There are two solutions, but as we know that cameramen
34
+ try to minimize roll, we take the solution with the smallest roll.
35
+ :param rotation: rotation matrix
36
+ :return: pan, tilt and roll in radians
37
+ """
38
+ orientation = np.transpose(rotation)
39
+ first_tilt = np.arccos(orientation[2, 2])
40
+ second_tilt = - first_tilt
41
+
42
+ sign_first_tilt = 1. if np.sin(first_tilt) > 0. else -1.
43
+ sign_second_tilt = 1. if np.sin(second_tilt) > 0. else -1.
44
+
45
+ first_pan = np.arctan2(sign_first_tilt * orientation[0, 2], sign_first_tilt * - orientation[1, 2])
46
+ second_pan = np.arctan2(sign_second_tilt * orientation[0, 2], sign_second_tilt * - orientation[1, 2])
47
+ first_roll = np.arctan2(sign_first_tilt * orientation[2, 0], sign_first_tilt * orientation[2, 1])
48
+ second_roll = np.arctan2(sign_second_tilt * orientation[2, 0], sign_second_tilt * orientation[2, 1])
49
+
50
+ # print(f"first solution {first_pan*180./np.pi}, {first_tilt*180./np.pi}, {first_roll*180./np.pi}")
51
+ # print(f"second solution {second_pan*180./np.pi}, {second_tilt*180./np.pi}, {second_roll*180./np.pi}")
52
+ if np.fabs(first_roll) < np.fabs(second_roll):
53
+ return first_pan, first_tilt, first_roll
54
+ return second_pan, second_tilt, second_roll
55
+
56
+
57
+ def unproject_image_point(homography, point2D):
58
+ """
59
+ Given the homography from the world plane of the pitch and the image and a point localized on the pitch plane in the
60
+ image, returns the coordinates of the point in the 3D pitch plane.
61
+ /!\ Only works for correspondences on the pitch (Z = 0).
62
+ :param homography: the homography
63
+ :param point2D: the image point whose relative coordinates on the world plane of the pitch are to be found
64
+ :return: A 2D point on the world pitch plane in homogenous coordinates (X,Y,1) with X and Y being the world
65
+ coordinates of the point.
66
+ """
67
+ hinv = np.linalg.inv(homography)
68
+ pitchpoint = hinv @ point2D
69
+ pitchpoint = pitchpoint / pitchpoint[2]
70
+ return pitchpoint
71
+
72
+
73
+ class Camera:
74
+
75
+ def __init__(self, iwidth=960, iheight=540):
76
+ self.position = np.zeros(3)
77
+ self.rotation = np.eye(3)
78
+ self.calibration = np.eye(3)
79
+ self.radial_distortion = np.zeros(6)
80
+ self.thin_prism_disto = np.zeros(4)
81
+ self.tangential_disto = np.zeros(2)
82
+ self.image_width = iwidth
83
+ self.image_height = iheight
84
+ self.xfocal_length = 1
85
+ self.yfocal_length = 1
86
+ self.principal_point = (self.image_width / 2, self.image_height / 2)
87
+
88
+ def solve_pnp(self, point_matches):
89
+ """
90
+ With a known calibration matrix, this method can be used in order to retrieve rotation and translation camera
91
+ parameters.
92
+ :param point_matches: A list of pairs of 3D-2D point matches .
93
+ """
94
+ target_pts = np.array([pt[0] for pt in point_matches])
95
+ src_pts = np.array([pt[1] for pt in point_matches])
96
+ _, rvec, t, inliers = cv.solvePnPRansac(target_pts, src_pts, self.calibration, None)
97
+ self.rotation, _ = cv.Rodrigues(rvec)
98
+ self.position = - np.transpose(self.rotation) @ t.flatten()
99
+
100
+ def refine_camera(self, pointMatches):
101
+ """
102
+ Once that there is a minimal set of initial camera parameters (calibration, rotation and position roughly known),
103
+ this method can be used to refine the solution using a non-linear optimization procedure.
104
+ :param pointMatches: A list of pairs of 3D-2D point matches .
105
+
106
+ """
107
+ rvec, _ = cv.Rodrigues(self.rotation)
108
+ target_pts = np.array([pt[0] for pt in pointMatches])
109
+ src_pts = np.array([pt[1] for pt in pointMatches])
110
+
111
+ rvec, t = cv.solvePnPRefineLM(target_pts, src_pts, self.calibration, None, rvec, -self.rotation @ self.position,
112
+ (cv.TERM_CRITERIA_MAX_ITER + cv.TERM_CRITERIA_EPS, 20000, 0.00001))
113
+ self.rotation, _ = cv.Rodrigues(rvec)
114
+ self.position = - np.transpose(self.rotation) @ t
115
+
116
+ def from_homography(self, homography):
117
+ """
118
+ This method initializes the essential camera parameters from the homography between the world plane of the pitch
119
+ and the image. It is based on the extraction of the calibration matrix from the homography (Algorithm 8.2 of
120
+ Multiple View Geometry in computer vision, p225), then using the relation between the camera parameters and the
121
+ same homography, we extract rough rotation and position estimates (Example 8.1 of Multiple View Geometry in
122
+ computer vision, p196).
123
+ :param homography: The homography that captures the transformation between the 3D flat model of the soccer pitch
124
+ and its image.
125
+ """
126
+ success, _ = self.estimate_calibration_matrix_from_plane_homography(homography)
127
+ if not success:
128
+ return False
129
+
130
+ hprim = np.linalg.inv(self.calibration) @ homography
131
+ lambda1 = 1 / np.linalg.norm(hprim[:, 0])
132
+ lambda2 = 1 / np.linalg.norm(hprim[:, 1])
133
+ lambda3 = np.sqrt(lambda1 * lambda2)
134
+
135
+ r0 = hprim[:, 0] * lambda1
136
+ r1 = hprim[:, 1] * lambda2
137
+ r2 = np.cross(r0, r1)
138
+
139
+ R = np.column_stack((r0, r1, r2))
140
+ u, s, vh = np.linalg.svd(R)
141
+ R = u @ vh
142
+ if np.linalg.det(R) < 0:
143
+ u[:, 2] *= -1
144
+ R = u @ vh
145
+ self.rotation = R
146
+ t = hprim[:, 2] * lambda3
147
+ self.position = - np.transpose(R) @ t
148
+ return True
149
+
150
+ def to_json_parameters(self):
151
+ """
152
+ Saves camera to a JSON serializable dictionary.
153
+ :return: The dictionary
154
+ """
155
+ pan, tilt, roll = rotation_matrix_to_pan_tilt_roll(self.rotation)
156
+ camera_dict = {
157
+ "pan_degrees": pan * 180. / np.pi,
158
+ "tilt_degrees": tilt * 180. / np.pi,
159
+ "roll_degrees": roll * 180. / np.pi,
160
+ "position_meters": self.position.tolist(),
161
+ "x_focal_length": self.xfocal_length,
162
+ "y_focal_length": self.yfocal_length,
163
+ "principal_point": [self.principal_point[0], self.principal_point[1]],
164
+ "radial_distortion": self.radial_distortion.tolist(),
165
+ "tangential_distortion": self.tangential_disto.tolist(),
166
+ "thin_prism_distortion": self.thin_prism_disto.tolist()
167
+
168
+ }
169
+ return camera_dict
170
+
171
+ def from_json_parameters(self, calib_json_object):
172
+ """
173
+ Loads camera parameters from dictionary.
174
+ :param calib_json_object: the dictionary containing camera parameters.
175
+ """
176
+ self.principal_point = calib_json_object["principal_point"]
177
+ self.image_width = 2 * self.principal_point[0]
178
+ self.image_height = 2 * self.principal_point[1]
179
+ self.xfocal_length = calib_json_object["x_focal_length"]
180
+ self.yfocal_length = calib_json_object["y_focal_length"]
181
+
182
+ self.calibration = np.array([
183
+ [self.xfocal_length, 0, self.principal_point[0]],
184
+ [0, self.yfocal_length, self.principal_point[1]],
185
+ [0, 0, 1]
186
+ ], dtype='float')
187
+
188
+ pan = calib_json_object['pan_degrees'] * np.pi / 180.
189
+ tilt = calib_json_object['tilt_degrees'] * np.pi / 180.
190
+ roll = calib_json_object['roll_degrees'] * np.pi / 180.
191
+
192
+ self.rotation = np.array([
193
+ [-np.sin(pan) * np.sin(roll) * np.cos(tilt) + np.cos(pan) * np.cos(roll),
194
+ np.sin(pan) * np.cos(roll) + np.sin(roll) * np.cos(pan) * np.cos(tilt), np.sin(roll) * np.sin(tilt)],
195
+ [-np.sin(pan) * np.cos(roll) * np.cos(tilt) - np.sin(roll) * np.cos(pan),
196
+ -np.sin(pan) * np.sin(roll) + np.cos(pan) * np.cos(roll) * np.cos(tilt), np.sin(tilt) * np.cos(roll)],
197
+ [np.sin(pan) * np.sin(tilt), -np.sin(tilt) * np.cos(pan), np.cos(tilt)]
198
+ ], dtype='float')
199
+
200
+ self.rotation = np.transpose(pan_tilt_roll_to_orientation(pan, tilt, roll))
201
+
202
+ self.position = np.array(calib_json_object['position_meters'], dtype='float')
203
+
204
+ self.radial_distortion = np.array(calib_json_object['radial_distortion'], dtype='float')
205
+ self.tangential_disto = np.array(calib_json_object['tangential_distortion'], dtype='float')
206
+ self.thin_prism_disto = np.array(calib_json_object['thin_prism_distortion'], dtype='float')
207
+
208
+ def distort(self, point):
209
+ """
210
+ Given a point in the normalized image plane, apply distortion
211
+ :param point: 2D point on the normalized image plane
212
+ :return: 2D distorted point
213
+ """
214
+ numerator = 1
215
+ denominator = 1
216
+ radius = np.sqrt(point[0] * point[0] + point[1] * point[1])
217
+
218
+ for i in range(3):
219
+ k = self.radial_distortion[i]
220
+ numerator += k * radius ** (2 * (i + 1))
221
+ k2n = self.radial_distortion[i + 3]
222
+ denominator += k2n * radius ** (2 * (i + 1))
223
+
224
+ radial_distortion_factor = numerator / denominator
225
+ xpp = point[0] * radial_distortion_factor + \
226
+ 2 * self.tangential_disto[0] * point[0] * point[1] + self.tangential_disto[1] * (
227
+ radius ** 2 + 2 * point[0] ** 2) + \
228
+ self.thin_prism_disto[0] * radius ** 2 + self.thin_prism_disto[1] * radius ** 4
229
+ ypp = point[1] * radial_distortion_factor + \
230
+ 2 * self.tangential_disto[1] * point[0] * point[1] + self.tangential_disto[0] * (
231
+ radius ** 2 + 2 * point[1] ** 2) + \
232
+ self.thin_prism_disto[2] * radius ** 2 + self.thin_prism_disto[3] * radius ** 4
233
+ return np.array([xpp, ypp], dtype=np.float32)
234
+
235
+ def project_point(self, point3D, distort=True):
236
+ """
237
+ Uses current camera parameters to predict where a 3D point is seen by the camera.
238
+ :param point3D: The 3D point in world coordinates.
239
+ :param distort: optional parameter to allow projection without distortion.
240
+ :return: The 2D coordinates of the imaged point
241
+ """
242
+ point = point3D - self.position
243
+ rotated_point = self.rotation @ np.transpose(point)
244
+ if rotated_point[2] <= 1e-3 :
245
+ return np.zeros(3)
246
+ rotated_point = rotated_point / rotated_point[2]
247
+ if distort:
248
+ distorted_point = self.distort(rotated_point)
249
+ else:
250
+ distorted_point = rotated_point
251
+ x = distorted_point[0] * self.xfocal_length + self.principal_point[0]
252
+ y = distorted_point[1] * self.yfocal_length + self.principal_point[1]
253
+ return np.array([x, y, 1])
254
+
255
+ def scale_resolution(self, factor):
256
+ """
257
+ Adapts the internal parameters for image resolution changes
258
+ :param factor: scaling factor
259
+ """
260
+ self.xfocal_length = self.xfocal_length * factor
261
+ self.yfocal_length = self.yfocal_length * factor
262
+ self.image_width = self.image_width * factor
263
+ self.image_height = self.image_height * factor
264
+
265
+ self.principal_point = (self.image_width / 2, self.image_height / 2)
266
+
267
+ self.calibration = np.array([
268
+ [self.xfocal_length, 0, self.principal_point[0]],
269
+ [0, self.yfocal_length, self.principal_point[1]],
270
+ [0, 0, 1]
271
+ ], dtype='float')
272
+
273
+ def draw_corners(self, image, color=(0, 255, 0)):
274
+ """
275
+ Draw the corners of a standard soccer pitch in the image.
276
+ :param image: cv image
277
+ :param color
278
+ :return: the image mat modified.
279
+ """
280
+ field = SoccerPitch()
281
+ for pt3D in field.point_dict.values():
282
+ projected = self.project_point(pt3D)
283
+ if projected[2] == 0.:
284
+ continue
285
+ projected /= projected[2]
286
+ if 0 < projected[0] < self.image_width and 0 < projected[1] < self.image_height:
287
+ cv.circle(image, (int(projected[0]), int(projected[1])), 3, color, 2)
288
+ return image
289
+
290
+ def draw_pitch(self, image, color=(0, 255, 0)):
291
+ """
292
+ Draws all the lines of the pitch on the image.
293
+ :param image
294
+ :param color
295
+ :return: modified image
296
+ """
297
+ field = SoccerPitch()
298
+
299
+ polylines = field.sample_field_points()
300
+ for line in polylines.values():
301
+ prev_point = self.project_point(line[0])
302
+ for point in line[1:]:
303
+ projected = self.project_point(point)
304
+ if projected[2] == 0.:
305
+ continue
306
+ projected /= projected[2]
307
+ if 0 < projected[0] < self.image_width and 0 < projected[1] < self.image_height:
308
+ cv.line(image, (int(prev_point[0]), int(prev_point[1])), (int(projected[0]), int(projected[1])),
309
+ color, 1)
310
+ prev_point = projected
311
+ return image
312
+
313
+ def draw_colorful_pitch(self, image, palette):
314
+ """
315
+ Draws all the lines of the pitch on the image, each line color is specified by the palette argument.
316
+
317
+ :param image:
318
+ :param palette: dictionary associating line classes names with their BGR color.
319
+ :return: modified image
320
+ """
321
+ field = SoccerPitch()
322
+
323
+ polylines = field.sample_field_points()
324
+ for key, line in polylines.items():
325
+ if key not in palette.keys():
326
+ print(f"Can't draw {key}")
327
+ continue
328
+ prev_point = self.project_point(line[0])
329
+ for point in line[1:]:
330
+ projected = self.project_point(point)
331
+ if projected[2] == 0.:
332
+ continue
333
+ projected /= projected[2]
334
+ if 0 < projected[0] < self.image_width and 0 < projected[1] < self.image_height:
335
+ # BGR color
336
+ cv.line(image, (int(prev_point[0]), int(prev_point[1])), (int(projected[0]), int(projected[1])),
337
+ palette[key][::-1], 1)
338
+ prev_point = projected
339
+ return image
340
+
341
+ def estimate_calibration_matrix_from_plane_homography(self, homography):
342
+ """
343
+ This method initializes the calibration matrix from the homography between the world plane of the pitch
344
+ and the image. It is based on the extraction of the calibration matrix from the homography (Algorithm 8.2 of
345
+ Multiple View Geometry in computer vision, p225). The extraction is sensitive to noise, which is why we keep the
346
+ principal point in the middle of the image rather than using the one extracted by this method.
347
+ :param homography: homography between the world plane of the pitch and the image
348
+ """
349
+ H = np.reshape(homography, (9,))
350
+ A = np.zeros((5, 6))
351
+ A[0, 1] = 1.
352
+ A[1, 0] = 1.
353
+ A[1, 2] = -1.
354
+ A[2, 3] = self.principal_point[1] / self.principal_point[0]
355
+ A[2, 4] = -1.0
356
+ A[3, 0] = H[0] * H[1]
357
+ A[3, 1] = H[0] * H[4] + H[1] * H[3]
358
+ A[3, 2] = H[3] * H[4]
359
+ A[3, 3] = H[0] * H[7] + H[1] * H[6]
360
+ A[3, 4] = H[3] * H[7] + H[4] * H[6]
361
+ A[3, 5] = H[6] * H[7]
362
+ A[4, 0] = H[0] * H[0] - H[1] * H[1]
363
+ A[4, 1] = 2 * H[0] * H[3] - 2 * H[1] * H[4]
364
+ A[4, 2] = H[3] * H[3] - H[4] * H[4]
365
+ A[4, 3] = 2 * H[0] * H[6] - 2 * H[1] * H[7]
366
+ A[4, 4] = 2 * H[3] * H[6] - 2 * H[4] * H[7]
367
+ A[4, 5] = H[6] * H[6] - H[7] * H[7]
368
+
369
+ u, s, vh = np.linalg.svd(A)
370
+ w = vh[-1]
371
+ W = np.zeros((3, 3))
372
+ W[0, 0] = w[0] / w[5]
373
+ W[0, 1] = w[1] / w[5]
374
+ W[0, 2] = w[3] / w[5]
375
+ W[1, 0] = w[1] / w[5]
376
+ W[1, 1] = w[2] / w[5]
377
+ W[1, 2] = w[4] / w[5]
378
+ W[2, 0] = w[3] / w[5]
379
+ W[2, 1] = w[4] / w[5]
380
+ W[2, 2] = w[5] / w[5]
381
+
382
+ try:
383
+ Ktinv = np.linalg.cholesky(W)
384
+ except np.linalg.LinAlgError:
385
+ K = np.eye(3)
386
+ return False, K
387
+
388
+ K = np.linalg.inv(np.transpose(Ktinv))
389
+ K /= K[2, 2]
390
+
391
+ self.xfocal_length = K[0, 0]
392
+ self.yfocal_length = K[1, 1]
393
+ # the principal point estimated by this method is very noisy, better keep it in the center of the image
394
+ self.principal_point = (self.image_width / 2, self.image_height / 2)
395
+ # self.principal_point = (K[0,2], K[1,2])
396
+ self.calibration = np.array([
397
+ [self.xfocal_length, 0, self.principal_point[0]],
398
+ [0, self.yfocal_length, self.principal_point[1]],
399
+ [0, 0, 1]
400
+ ], dtype='float')
401
+ return True, K
402
+
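A minimal usage sketch for the Camera class above (a sketch only: it assumes the code is run from sn_calibration/src so that camera.py and soccerpitch.py are importable, and the camera JSON file name is illustrative):

import json
import numpy as np

from camera import Camera
from soccerpitch import SoccerPitch

# Illustrative file name: one SoccerNet-style camera JSON as produced by to_json_parameters()
with open("camera_00001.json", "r") as f:
    calib = json.load(f)

cam = Camera(960, 540)
cam.from_json_parameters(calib)

# Project one 3D pitch landmark into the image
field = SoccerPitch()
point3d = np.array(list(field.point_dict.values())[0])
uv = cam.project_point(point3d)
if uv[2] != 0.:
    print("projected pixel:", uv[0], uv[1])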
sn_calibration/src/dataloader.py ADDED
@@ -0,0 +1,122 @@
1
+ """
2
+ DataLoader used to train the segmentation network that predicts line extremities.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ import time
8
+ from argparse import ArgumentParser
9
+
10
+ import cv2 as cv
11
+ import numpy as np
12
+ from torch.utils.data import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from src.soccerpitch import SoccerPitch
16
+
17
+
18
+ class SoccerNetDataset(Dataset):
19
+ def __init__(self,
20
+ datasetpath,
21
+ split="test",
22
+ width=640,
23
+ height=360,
24
+ mean="../resources/mean.npy",
25
+ std="../resources/std.npy"):
26
+ self.mean = np.load(mean)
27
+ self.std = np.load(std)
28
+ self.width = width
29
+ self.height = height
30
+
31
+ dataset_dir = os.path.join(datasetpath, split)
32
+ if not os.path.exists(dataset_dir):
33
+ print("Invalid dataset path!")
34
+ exit(-1)
35
+
36
+ frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
37
+
38
+ self.data = []
39
+ self.n_samples = 0
40
+ for frame in frames:
41
+
42
+ frame_index = frame.split(".")[0]
43
+ annotation_file = os.path.join(dataset_dir, f"{frame_index}.json")
44
+ if not os.path.exists(annotation_file):
45
+ continue
46
+ with open(annotation_file, "r") as f:
47
+ groundtruth_lines = json.load(f)
48
+ img_path = os.path.join(dataset_dir, frame)
49
+ if groundtruth_lines:
50
+ self.data.append({
51
+ "image_path": img_path,
52
+ "annotations": groundtruth_lines,
53
+ })
54
+
55
+ def __len__(self):
56
+ return len(self.data)
57
+
58
+ def __getitem__(self, index):
59
+ item = self.data[index]
60
+
61
+ img = cv.imread(item["image_path"])
62
+ img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR)
63
+
64
+ mask = np.zeros(img.shape[:-1], dtype=np.uint8)
65
+ img = np.asarray(img, np.float32) / 255.
66
+ img -= self.mean
67
+ img /= self.std
68
+ img = img.transpose((2, 0, 1))
69
+ for class_number, class_ in enumerate(SoccerPitch.lines_classes):
70
+ if class_ in item["annotations"].keys():
71
+ key = class_
72
+ line = item["annotations"][key]
73
+ prev_point = line[0]
74
+ for i in range(1, len(line)):
75
+ next_point = line[i]
76
+ cv.line(mask,
77
+ (int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])),
78
+ (int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])),
79
+ class_number + 1,
80
+ 2)
81
+ prev_point = next_point
82
+ return img, mask
83
+
84
+
85
+ if __name__ == "__main__":
86
+
87
+ # Load the arguments
88
+ parser = ArgumentParser(description='dataloader')
89
+
90
+ parser.add_argument('--SoccerNet_path', default="./annotations/", type=str,
91
+ help='Path to the SoccerNet-V3 dataset folder')
92
+ parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games')
93
+ parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
94
+ parser.add_argument('--num_workers', required=False, type=int, default=4,
95
+ help='number of workers for the dataloader')
96
+ parser.add_argument('--resolution_width', required=False, type=int, default=1920,
97
+ help='width resolution of the images')
98
+ parser.add_argument('--resolution_height', required=False, type=int, default=1080,
99
+ help='height resolution of the images')
100
+ parser.add_argument('--preload_images', action='store_true',
101
+ help="Preload the images when constructing the dataset")
102
+ parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder")
103
+
104
+ args = parser.parse_args()
105
+
106
+ start_time = time.time()
107
+ soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split)
108
+ with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t:
109
+ for i, data in t:
110
+ img = soccernet[i][0].astype(np.uint8).transpose((1, 2, 0))
111
+ print(img.shape)
112
+ print(img.dtype)
113
+ cv.imshow("Normalized image", img)
114
+ cv.waitKey(0)
115
+ cv.destroyAllWindows()
116
+ print(data[1].shape)
117
+ cv.imshow("Mask", soccernet[i][1].astype(np.uint8))
118
+ cv.waitKey(0)
119
+ cv.destroyAllWindows()
120
+ continue
121
+ end_time = time.time()
122
+ print(end_time - start_time)
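A short sketch of how the dataset above might be consumed (assuming it is run from the sn_calibration directory so that src.dataloader is importable; all paths are illustrative):

from torch.utils.data import DataLoader

from src.dataloader import SoccerNetDataset

# Illustrative paths: the split folder must contain <frame>.jpg / <frame>.json pairs
dataset = SoccerNetDataset("./annotations/", split="test", width=640, height=360,
                           mean="resources/mean.npy", std="resources/std.npy")
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for images, masks in loader:
    # images: (B, 3, 360, 640) normalized; masks: (B, 360, 640) with one integer label per line class
    print(images.shape, masks.shape)
    break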
sn_calibration/src/detect_extremities.py ADDED
@@ -0,0 +1,314 @@
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import os.path
5
+ import random
6
+ from collections import deque
7
+ from pathlib import Path
8
+
9
+ import cv2 as cv
10
+ import numpy as np
11
+ import torch
12
+ import torch.backends.cudnn
13
+ import torch.nn as nn
14
+ from PIL import Image
15
+ from torchvision.models.segmentation import deeplabv3_resnet50
16
+ from tqdm import tqdm
17
+
18
+ from soccerpitch import SoccerPitch
19
+
20
+
21
+ def generate_class_synthesis(semantic_mask, radius):
22
+ """
23
+ This function selects for each class present in the semantic mask, a set of circles that cover most of the semantic
24
+ class blobs.
25
+ :param semantic_mask: an image containing the segmentation predictions
26
+ :param radius: circle radius
27
+ :return: a dictionary which associates with each detected class a list of points (the circle centers)
28
+ """
29
+ buckets = dict()
30
+ kernel = np.ones((5, 5), np.uint8)
31
+ semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
32
+ for k, class_name in enumerate(SoccerPitch.lines_classes):
33
+ mask = semantic_mask == k + 1
34
+ if mask.sum() > 0:
35
+ disk_list = synthesize_mask(mask, radius)
36
+ if len(disk_list):
37
+ buckets[class_name] = disk_list
38
+
39
+ return buckets
40
+
41
+
42
+ def join_points(point_list, maxdist):
43
+ """
44
+ Given a list of points that were extracted from the blobs belonging to the same semantic class, this function creates
45
+ polylines by linking close points together if their distance is below the maxdist threshold.
46
+ :param point_list: List of points of the same line class
47
+ :param maxdist: maximum distance allowed between consecutive points of the same polyline.
48
+ :return: a list of polylines
49
+ """
50
+ polylines = []
51
+
52
+ if not len(point_list):
53
+ return polylines
54
+ head = point_list[0]
55
+ tail = point_list[0]
56
+ polyline = deque()
57
+ polyline.append(point_list[0])
58
+ remaining_points = copy.deepcopy(point_list[1:])
59
+
60
+ while len(remaining_points) > 0:
61
+ min_dist_tail = 1000
62
+ min_dist_head = 1000
63
+ best_head = -1
64
+ best_tail = -1
65
+ for j, point in enumerate(remaining_points):
66
+ dist_tail = np.sqrt(np.sum(np.square(point - tail)))
67
+ dist_head = np.sqrt(np.sum(np.square(point - head)))
68
+ if dist_tail < min_dist_tail:
69
+ min_dist_tail = dist_tail
70
+ best_tail = j
71
+ if dist_head < min_dist_head:
72
+ min_dist_head = dist_head
73
+ best_head = j
74
+
75
+ if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
76
+ polyline.appendleft(remaining_points[best_head])
77
+ head = polyline[0]
78
+ remaining_points.pop(best_head)
79
+ elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
80
+ polyline.append(remaining_points[best_tail])
81
+ tail = polyline[-1]
82
+ remaining_points.pop(best_tail)
83
+ else:
84
+ polylines.append(list(polyline.copy()))
85
+ head = remaining_points[0]
86
+ tail = remaining_points[0]
87
+ polyline = deque()
88
+ polyline.append(head)
89
+ remaining_points.pop(0)
90
+ polylines.append(list(polyline))
91
+ return polylines
92
+
93
+
94
+ def get_line_extremities(buckets, maxdist, width, height):
95
+ """
96
+ Given the dictionary {lines_class: points}, finds plausible extremities of each line, i.e. the extremities
97
+ of the longest polyline that can be built on the class blobs, and normalizes its coordinates
98
+ by the image size.
99
+ :param buckets: The dictionary associating line classes to the set of circle centers that best covers the class
100
+ prediction blobs in the segmentation mask
101
+ :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
102
+ :param width: image width
103
+ :param height: image height
104
+ :return: a dictionary associating to each class its extremities
105
+ """
106
+ extremities = dict()
107
+ for class_name, disks_list in buckets.items():
108
+ polyline_list = join_points(disks_list, maxdist)
109
+ max_len = 0
110
+ longest_polyline = []
111
+ for polyline in polyline_list:
112
+ if len(polyline) > max_len:
113
+ max_len = len(polyline)
114
+ longest_polyline = polyline
115
+ extremities[class_name] = [
116
+ {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
117
+ {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height}
118
+ ]
119
+ return extremities
120
+
121
+
122
+ def get_support_center(mask, start, disk_radius, min_support=0.1):
123
+ """
124
+ Returns the barycenter of the True pixels of the mask that lie inside the circle of center start and
125
+ radius disk_radius (in pixels).
126
+ :param mask: Boolean mask
127
+ :param start: A point located on a true pixel of the mask
128
+ :param disk_radius: the radius of the circles
129
+ :param min_support: proportion of the circle area that should cover True pixels in order to count as enough support
130
+ :return: A boolean indicating if there is enough support in the circle area, the barycenter of the True pixels under
131
+ the circle
132
+ """
133
+ x = int(start[0])
134
+ y = int(start[1])
135
+ support_pixels = 1
136
+ result = [x, y]
137
+ xstart = x - disk_radius
138
+ if xstart < 0:
139
+ xstart = 0
140
+ xend = x + disk_radius
141
+ if xend >= mask.shape[0]:
142
+ xend = mask.shape[0] - 1
143
+
144
+ ystart = y - disk_radius
145
+ if ystart < 0:
146
+ ystart = 0
147
+ yend = y + disk_radius
148
+ if yend >= mask.shape[1]:
149
+ yend = mask.shape[1] - 1
150
+
151
+ for i in range(xstart, xend + 1):
152
+ for j in range(ystart, yend + 1):
153
+ dist = np.sqrt(np.square(x - i) + np.square(y - j))
154
+ if dist < disk_radius and mask[i, j] > 0:
155
+ support_pixels += 1
156
+ result[0] += i
157
+ result[1] += j
158
+ support = True
159
+ if support_pixels < min_support * np.square(disk_radius) * np.pi:
160
+ support = False
161
+
162
+ result = np.array(result)
163
+ result = np.true_divide(result, support_pixels)
164
+
165
+ return support, result
166
+
167
+
168
+ def synthesize_mask(semantic_mask, disk_radius):
169
+ """
170
+ Fits circles on the True pixels of the mask and returns those which have enough support, meaning that the
171
+ proportion of the circle area covering True pixels is higher than a certain threshold, in order to avoid
172
+ fitting circles on isolated pixels.
173
+ :param semantic_mask: boolean mask
174
+ :param disk_radius: radius of the circles
175
+ :return: a list of disk centers, that have enough support
176
+ """
177
+ mask = semantic_mask.copy().astype(np.uint8)
178
+ points = np.transpose(np.nonzero(mask))
179
+ disks = []
180
+ while len(points):
181
+
182
+ start = random.choice(points)
183
+ dist = 10.
184
+ success = True
185
+ while dist > 1.:
186
+ enough_support, center = get_support_center(mask, start, disk_radius)
187
+ if not enough_support:
188
+ bad_point = np.round(center).astype(np.int32)
189
+ cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
190
+ success = False
191
+ dist = np.sqrt(np.sum(np.square(center - start)))
192
+ start = center
193
+ if success:
194
+ disks.append(np.round(start).astype(np.int32))
195
+ cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
196
+ points = np.transpose(np.nonzero(mask))
197
+
198
+ return disks
199
+
200
+
201
+ class SegmentationNetwork:
202
+ def __init__(self, model_file, mean_file, std_file, num_classes=29, width=640, height=360):
203
+ file_path = Path(model_file).resolve()
204
+ model = nn.DataParallel(deeplabv3_resnet50(pretrained=False, num_classes=num_classes))
205
+ self.init_weight(model, nn.init.kaiming_normal_,
206
+ nn.BatchNorm2d, 1e-3, 0.1,
207
+ mode='fan_in')
208
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
209
+ checkpoint = torch.load(str(file_path), map_location=self.device)
210
+ model.load_state_dict(checkpoint["model"])
211
+ model.eval()
212
+ self.model = model.to(self.device)
213
+ file_path = Path(mean_file).resolve()
214
+ self.mean = np.load(str(file_path))
215
+ file_path = Path(std_file).resolve()
216
+ self.std = np.load(str(file_path))
217
+ self.width = width
218
+ self.height = height
219
+
220
+ def init_weight(self, feature, conv_init, norm_layer, bn_eps, bn_momentum,
221
+ **kwargs):
222
+ for name, m in feature.named_modules():
223
+ if isinstance(m, (nn.Conv2d, nn.Conv3d)):
224
+ conv_init(m.weight, **kwargs)
225
+ elif isinstance(m, norm_layer):
226
+ m.eps = bn_eps
227
+ m.momentum = bn_momentum
228
+ nn.init.constant_(m.weight, 1)
229
+ nn.init.constant_(m.bias, 0)
230
+
231
+ def analyse_image(self, image):
232
+ """
233
+ Process image and perform inference, returns mask of detected classes
234
+ :param image: BGR image
235
+ :return: predicted classes mask
236
+ """
237
+ img = cv.resize(image, (self.width, self.height), interpolation=cv.INTER_LINEAR)
238
+ img = np.asarray(img, np.float32) / 255.
239
+ img = (img - self.mean) / self.std
240
+ img = img.transpose((2, 0, 1))
241
+ img = torch.from_numpy(img).to(self.device).unsqueeze(0)
242
+
243
+ cuda_result = self.model.forward(img.float())
244
+ output = cuda_result['out'].data[0].cpu().numpy()
245
+ output = output.transpose(1, 2, 0)
246
+ output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
247
+
248
+ return output
249
+
250
+
251
+ if __name__ == "__main__":
252
+ parser = argparse.ArgumentParser(description='Test')
253
+
254
+ parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023-bis/", type=str,
255
+ help='Path to the SoccerNet-V3 dataset folder')
256
+ parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/", required=False, type=str,
257
+ help="Path to the prediction folder")
258
+ parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
259
+ parser.add_argument('--masks', required=False, type=bool, default=False, help='Save masks in prediction directory')
260
+ parser.add_argument('--resolution_width', required=False, type=int, default=640,
261
+ help='width resolution of the images')
262
+ parser.add_argument('--resolution_height', required=False, type=int, default=360,
263
+ help='height resolution of the images')
264
+ args = parser.parse_args()
265
+
266
+ lines_palette = [0, 0, 0]
267
+ for line_class in SoccerPitch.lines_classes:
268
+ lines_palette.extend(SoccerPitch.palette[line_class])
269
+
270
+ calib_net = SegmentationNetwork(
271
+ "../resources/soccer_pitch_segmentation.pth",
272
+ "../resources/mean.npy",
273
+ "../resources/std.npy")
274
+
275
+ dataset_dir = os.path.join(args.soccernet, args.split)
276
+ if not os.path.exists(dataset_dir):
277
+ print("Invalid dataset path !")
278
+ exit(-1)
279
+
280
+ with open(os.path.join(dataset_dir, "per_match_info.json"), 'r') as f:
281
+ match_info = json.load(f)
282
+
283
+ with tqdm(enumerate(match_info.keys()), total=len(match_info.keys()), ncols=160) as t:
284
+ for i, match in t:
285
+ frame_list = match_info[match].keys()
286
+
287
+ for frame in frame_list:
288
+
289
+ output_prediction_folder = os.path.join(args.prediction, args.split)
290
+ if not os.path.exists(output_prediction_folder):
291
+ os.makedirs(output_prediction_folder)
292
+ prediction = dict()
293
+ count = 0
294
+
295
+ frame_path = os.path.join(dataset_dir, frame)
296
+
297
+ frame_index = frame.split(".")[0]
298
+
299
+ image = cv.imread(frame_path)
300
+ semlines = calib_net.analyse_image(image)
301
+ if args.masks:
302
+ mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
303
+ mask.putpalette(lines_palette)
304
+ mask_file = os.path.join(output_prediction_folder, frame)
305
+ mask.save(mask_file)
306
+ skeletons = generate_class_synthesis(semlines, 6)
307
+ extremities = get_line_extremities(skeletons, 40, args.resolution_width, args.resolution_height)
308
+
309
+ prediction = extremities
310
+ count += 1
311
+
312
+ prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
313
+ with open(prediction_file, "w") as f:
314
+ json.dump(prediction, f, indent=4)
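A compact sketch of the extremities pipeline implemented by this script, for a single frame (assumptions: it is run from sn_calibration/src like the __main__ block above, and the frame file name is illustrative):

import cv2 as cv

from detect_extremities import SegmentationNetwork, generate_class_synthesis, get_line_extremities

# Checkpoint and normalization files as used by the __main__ block above
net = SegmentationNetwork("../resources/soccer_pitch_segmentation.pth",
                          "../resources/mean.npy",
                          "../resources/std.npy")

image = cv.imread("frame_00001.jpg")              # BGR frame (illustrative name)
mask = net.analyse_image(image)                   # per-pixel line-class labels
skeletons = generate_class_synthesis(mask, 6)     # circle centers covering each class blob
extremities = get_line_extremities(skeletons, 40, 640, 360)
# extremities: {line class: [{'x': ..., 'y': ...}, {'x': ..., 'y': ...}]} in normalized coordinates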
sn_calibration/src/evalai_camera.py ADDED
@@ -0,0 +1,133 @@
1
+ import numpy as np
2
+ import zipfile
3
+ import argparse
4
+ import json
5
+ import sys
6
+
7
+ from tqdm import tqdm
8
+
9
+ sys.path.append("sn_calibration")
10
+ sys.path.append("sn_calibration/src")
11
+
12
+ from evaluate_camera import get_polylines, scale_points, evaluate_camera_prediction
13
+ from evaluate_extremities import mirror_labels
14
+
15
+
16
+
17
+ def evaluate(gt_zip, prediction_zip, width=960, height=540):
18
+ gt_archive = zipfile.ZipFile(gt_zip, 'r')
19
+ prediction_archive = zipfile.ZipFile(prediction_zip, 'r')
20
+ gt_jsons = gt_archive.namelist()
21
+
22
+ accuracies = []
23
+ precisions = []
24
+ recalls = []
25
+ dict_errors = {}
26
+ per_class_confusion_dict = {}
27
+ total_frames = 0
28
+ missed = 0
29
+ for gt_json in tqdm(gt_jsons):
30
+
31
+ #split, name = gt_json.split("/")
32
+ #pred_name = f"{split}/camera_{name}"
33
+ pred_name = f"camera_{gt_json}"
34
+
35
+ total_frames += 1
36
+
37
+ if pred_name not in prediction_archive.namelist():
38
+ missed += 1
39
+ continue
40
+
41
+ prediction = prediction_archive.read(pred_name)
42
+ prediction = json.loads(prediction.decode("utf-8"))
43
+ gt = gt_archive.read(gt_json)
44
+ gt = json.loads(gt.decode('utf-8'))
45
+
46
+ line_annotations = scale_points(gt, width, height)
47
+
48
+ img_groundtruth = line_annotations
49
+
50
+ img_prediction = get_polylines(prediction, width, height,
51
+ sampling_factor=0.9)
52
+
53
+ confusion1, per_class_conf1, reproj_errors1 = evaluate_camera_prediction(img_prediction,
54
+ img_groundtruth,
55
+ 5)
56
+
57
+ confusion2, per_class_conf2, reproj_errors2 = evaluate_camera_prediction(img_prediction,
58
+ mirror_labels(img_groundtruth),
59
+ 5)
60
+
61
+ accuracy1, accuracy2 = 0., 0.
62
+ if confusion1.sum() > 0:
63
+ accuracy1 = confusion1[0, 0] / confusion1.sum()
64
+
65
+ if confusion2.sum() > 0:
66
+ accuracy2 = confusion2[0, 0] / confusion2.sum()
67
+
68
+ if accuracy1 > accuracy2:
69
+ accuracy = accuracy1
70
+ confusion = confusion1
71
+ per_class_conf = per_class_conf1
72
+ reproj_errors = reproj_errors1
73
+ else:
74
+ accuracy = accuracy2
75
+ confusion = confusion2
76
+ per_class_conf = per_class_conf2
77
+ reproj_errors = reproj_errors2
78
+
79
+ accuracies.append(accuracy)
80
+ if confusion[0, :].sum() > 0:
81
+ precision = confusion[0, 0] / (confusion[0, :].sum())
82
+ precisions.append(precision)
83
+ if (confusion[0, 0] + confusion[1, 0]) > 0:
84
+ recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
85
+ recalls.append(recall)
86
+
87
+ for line_class, errors in reproj_errors.items():
88
+ if line_class in dict_errors.keys():
89
+ dict_errors[line_class].extend(errors)
90
+ else:
91
+ dict_errors[line_class] = errors
92
+
93
+ for line_class, confusion_mat in per_class_conf.items():
94
+ if line_class in per_class_confusion_dict.keys():
95
+ per_class_confusion_dict[line_class] += confusion_mat
96
+ else:
97
+ per_class_confusion_dict[line_class] = confusion_mat
98
+
99
+
100
+ results = {}
101
+ results["completeness"] = (total_frames - missed) / total_frames
102
+ results["meanRecall"] = np.mean(recalls)
103
+ results["meanPrecision"] = np.mean(precisions)
104
+ results["meanAccuracies"] = np.mean(accuracies)
105
+ results["finalScore"] = results["completeness"] * results["meanAccuracies"]
106
+
107
+
108
+ for line_class, confusion_mat in per_class_confusion_dict.items():
109
+ class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
110
+ class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
111
+ class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
112
+ results[f"{line_class}Precision"] = class_precision
113
+ results[f"{line_class}Recall"] = class_recall
114
+ results[f"{line_class}Accuracy"] = class_accuracy
115
+ return results
116
+
117
+
118
+ if __name__ == "__main__":
119
+ parser = argparse.ArgumentParser(description='Evaluation camera calibration task')
120
+
121
+ parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023/test_secret.zip", type=str,
122
+ help='Path to the zip groundtruth folder')
123
+ parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/test.zip",
124
+ required=False, type=str,
125
+ help="Path to the zip prediction folder")
126
+ parser.add_argument('--width', type=int, default=960)
127
+ parser.add_argument('--height', type=int, default=540)
128
+
129
+ args = parser.parse_args()
130
+
131
+ results = evaluate(args.soccernet, args.prediction, args.width, args.height)
132
+ for key in results.keys():
133
+ print(f"{key}: {results[key]}")
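A quick usage sketch for the evaluate() helper above (archive names are illustrative; each zip holds one JSON per frame, with predictions named camera_<frame>.json):

from evalai_camera import evaluate

results = evaluate("test_secret.zip", "predictions.zip", width=960, height=540)
print(results["completeness"], results["meanAccuracies"], results["finalScore"])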
sn_calibration/src/evaluate_camera.py ADDED
@@ -0,0 +1,365 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ import os
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+ from camera import Camera
11
+ from evaluate_extremities import scale_points, distance, mirror_labels
12
+ from soccerpitch import SoccerPitch
13
+
14
+
15
+ def get_polylines(camera_annotation, width, height, sampling_factor=0.2):
16
+ """
17
+ Given a set of camera parameters, this function adapts the camera to the desired image resolution and then
18
+ projects the 3D points belonging to the terrain model in order to give a dictionary associating the classes
19
+ observed and the points projected in the image.
20
+
21
+ :param camera_annotation: camera parameters in their json/dictionary format
22
+ :param width: image width for evaluation
23
+ :param height: image height for evaluation
24
+ :return: a dictionary whose keys correspond to a class observed in the image (a line of the 3D model whose
25
+ projection falls in the image) and whose values are the lists of 2D projected points.
26
+ """
27
+
28
+ cam = Camera(width, height)
29
+ cam.from_json_parameters(camera_annotation)
30
+ field = SoccerPitch()
31
+ projections = dict()
32
+ sides = [
33
+ np.array([1, 0, 0]),
34
+ np.array([1, 0, -width + 1]),
35
+ np.array([0, 1, 0]),
36
+ np.array([0, 1, -height + 1])
37
+ ]
38
+ for key, points in field.sample_field_points(sampling_factor).items():
39
+ projections_list = []
40
+ in_img = False
41
+ prev_proj = np.zeros(3)
42
+ for i, point in enumerate(points):
43
+ ext = cam.project_point(point)
44
+ if ext[2] < 1e-5:
45
+ # point at infinity or behind camera
46
+ continue
47
+ if 0 <= ext[0] < width and 0 <= ext[1] < height:
48
+
49
+ if not in_img and i > 0:
50
+
51
+ line = np.cross(ext, prev_proj)
52
+ in_img_intersections = []
53
+ dist_to_ext = []
54
+ for side in sides:
55
+ intersection = np.cross(line, side)
56
+ intersection /= intersection[2]
57
+ if 0 <= intersection[0] < width and 0 <= intersection[1] < height:
58
+ in_img_intersections.append(intersection)
59
+ dist_to_ext.append(np.sqrt(np.sum(np.square(intersection - ext))))
60
+ if in_img_intersections:
61
+ intersection = in_img_intersections[np.argmin(dist_to_ext)]
62
+
63
+ projections_list.append(
64
+ {
65
+ "x": intersection[0],
66
+ "y": intersection[1]
67
+ }
68
+ )
69
+
70
+ projections_list.append(
71
+ {
72
+ "x": ext[0],
73
+ "y": ext[1]
74
+ }
75
+ )
76
+ in_img = True
77
+ elif in_img:
78
+ # first point out
79
+ line = np.cross(ext, prev_proj)
80
+
81
+ in_img_intersections = []
82
+ dist_to_ext = []
83
+ for side in sides:
84
+ intersection = np.cross(line, side)
85
+ intersection /= intersection[2]
86
+ if 0 <= intersection[0] < width and 0 <= intersection[1] < height:
87
+ in_img_intersections.append(intersection)
88
+ dist_to_ext.append(np.sqrt(np.sum(np.square(intersection - ext))))
89
+ if in_img_intersections:
90
+ intersection = in_img_intersections[np.argmin(dist_to_ext)]
91
+
92
+ projections_list.append(
93
+ {
94
+ "x": intersection[0],
95
+ "y": intersection[1]
96
+ }
97
+ )
98
+ in_img = False
99
+ prev_proj = ext
100
+ if len(projections_list):
101
+ projections[key] = projections_list
102
+ return projections
103
+
104
+
105
+ def distance_to_polyline(point, polyline):
106
+ """
107
+ Computes the Euclidean distance between a point and a polyline.
108
+ :param point: 2D point
109
+ :param polyline: a list of 2D points
110
+ :return: the distance value
111
+ """
112
+ if 0 < len(polyline) < 2:
113
+ dist = distance(point, polyline[0])
114
+ return dist
115
+ else:
116
+ dist_to_segments = []
117
+ point_np = np.array([point["x"], point["y"], 1])
118
+
119
+ for i in range(len(polyline) - 1):
120
+ origin_segment = np.array([
121
+ polyline[i]["x"],
122
+ polyline[i]["y"],
123
+ 1
124
+ ])
125
+ end_segment = np.array([
126
+ polyline[i + 1]["x"],
127
+ polyline[i + 1]["y"],
128
+ 1
129
+ ])
130
+ line = np.cross(origin_segment, end_segment)
131
+ line /= np.sqrt(np.square(line[0]) + np.square(line[1]))
132
+
133
+ # project point on line l
134
+ projected = np.cross((np.cross(np.array([line[0], line[1], 0]), point_np)), line)
135
+ projected = projected / projected[2]
136
+
137
+ v1 = projected - origin_segment
138
+ v2 = end_segment - origin_segment
139
+ k = np.dot(v1, v2) / np.dot(v2, v2)
140
+ if 0 < k < 1:
141
+
142
+ segment_distance = np.sqrt(np.sum(np.square(projected - point_np)))
143
+ else:
144
+ d1 = distance(point, polyline[i])
145
+ d2 = distance(point, polyline[i + 1])
146
+ segment_distance = np.min([d1, d2])
147
+
148
+ dist_to_segments.append(segment_distance)
149
+ return np.min(dist_to_segments)
150
+
151
+
152
+ def evaluate_camera_prediction(projected_lines, groundtruth_lines, threshold):
153
+ """
154
+ Computes confusion matrices for a level of precision specified by the threshold.
155
+ A groundtruth line is correctly classified if it lies at less than threshold pixels from a line of the prediction
156
+ of the same class.
157
+ Also computes the reprojection error of each groundtruth point: the reprojection error is the L2 distance between
158
+ the point and the projection of the line.
159
+ :param projected_lines: dictionary of detected lines classes as keys and associated predicted points as values
160
+ :param groundtruth_lines: dictionary of annotated lines classes as keys and associated annotated points as values
161
+ :param threshold: distance in pixels that distinguishes good matches from bad ones
162
+ :return: confusion matrix, per class confusion matrix & per class reprojection errors
163
+ """
164
+ global_confusion_mat = np.zeros((2, 2), dtype=np.float32)
165
+ per_class_confusion = {}
166
+ dict_errors = {}
167
+ detected_classes = set(projected_lines.keys())
168
+ groundtruth_classes = set(groundtruth_lines.keys())
169
+
170
+ false_positives_classes = detected_classes - groundtruth_classes
171
+ for false_positive_class in false_positives_classes:
172
+ # false_positives = len(projected_lines[false_positive_class])
173
+ if "Circle" not in false_positive_class:
174
+ # Count only extremities for lines, independently of soccer pitch sampling
175
+ false_positives = 2.
176
+ else:
177
+ false_positives = 9.
178
+ per_class_confusion[false_positive_class] = np.array([[0., false_positives], [0., 0.]])
179
+ global_confusion_mat[0, 1] += 1
180
+
181
+ false_negatives_classes = groundtruth_classes - detected_classes
182
+ for false_negatives_class in false_negatives_classes:
183
+ false_negatives = len(groundtruth_lines[false_negatives_class])
184
+ per_class_confusion[false_negatives_class] = np.array([[0., 0.], [false_negatives, 0.]])
185
+ global_confusion_mat[1, 0] += 1
186
+
187
+ common_classes = detected_classes - false_positives_classes
188
+
189
+ for detected_class in common_classes:
190
+
191
+ detected_points = projected_lines[detected_class]
192
+ groundtruth_points = groundtruth_lines[detected_class]
193
+
194
+ per_class_confusion[detected_class] = np.zeros((2, 2))
195
+
196
+ all_below_dist = 1
197
+ for point in groundtruth_points:
198
+
199
+ dist_to_poly = distance_to_polyline(point, detected_points)
200
+ if dist_to_poly < threshold:
201
+ per_class_confusion[detected_class][0, 0] += 1
202
+ else:
203
+ per_class_confusion[detected_class][0, 1] += 1
204
+ all_below_dist *= 0
205
+
206
+ if detected_class in dict_errors.keys():
207
+ dict_errors[detected_class].append(dist_to_poly)
208
+ else:
209
+ dict_errors[detected_class] = [dist_to_poly]
210
+
211
+ if all_below_dist:
212
+ global_confusion_mat[0, 0] += 1
213
+ else:
214
+ global_confusion_mat[0, 1] += 1
215
+
216
+ return global_confusion_mat, per_class_confusion, dict_errors
217
+
218
+
219
+ if __name__ == "__main__":
220
+
221
+ parser = argparse.ArgumentParser(description='Evaluation camera calibration task')
222
+
223
+ parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023-bis/", type=str,
224
+ help='Path to the SoccerNet-V3 dataset folder')
225
+ parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/",
226
+ required=False, type=str,
227
+ help="Path to the prediction folder")
228
+ parser.add_argument('-t', '--threshold', default=5, required=False, type=int,
229
+ help="Accuracy threshold in pixels")
230
+ parser.add_argument('--split', required=False, type=str, default="valid", help='Select the split of data')
231
+ parser.add_argument('--resolution_width', required=False, type=int, default=960,
232
+ help='width resolution of the images')
233
+ parser.add_argument('--resolution_height', required=False, type=int, default=540,
234
+ help='height resolution of the images')
235
+ args = parser.parse_args()
236
+
237
+ accuracies = []
238
+ precisions = []
239
+ recalls = []
240
+ dict_errors = {}
241
+ per_class_confusion_dict = {}
242
+
243
+ dataset_dir = os.path.join(args.soccernet, args.split)
244
+ if not os.path.exists(dataset_dir):
245
+ print("Invalid dataset path !")
246
+ exit(-1)
247
+
248
+ annotation_files = [f for f in os.listdir(dataset_dir) if ".json" in f]
249
+
250
+ missed, total_frames = 0, 0
251
+ with tqdm(enumerate(annotation_files), total=len(annotation_files), ncols=160) as t:
252
+ for i, annotation_file in t:
253
+ frame_index = annotation_file.split(".")[0]
254
+ annotation_file = os.path.join(args.soccernet, args.split, annotation_file)
255
+ prediction_file = os.path.join(args.prediction, args.split, f"camera_{frame_index}.json")
256
+
257
+ total_frames += 1
258
+
259
+ if not os.path.exists(prediction_file):
260
+ missed += 1
261
+
262
+ continue
263
+
264
+ with open(annotation_file, 'r') as f:
265
+ line_annotations = json.load(f)
266
+
267
+ with open(prediction_file, 'r') as f:
268
+ predictions = json.load(f)
269
+
270
+ line_annotations = scale_points(line_annotations, args.resolution_width, args.resolution_height)
271
+
272
+ image_path = os.path.join(args.soccernet, args.split, f"{frame_index}.jpg")
273
+
274
+ img_groundtruth = line_annotations
275
+
276
+ img_prediction = get_polylines(predictions, args.resolution_width, args.resolution_height,
277
+ sampling_factor=0.9)
278
+
279
+ confusion1, per_class_conf1, reproj_errors1 = evaluate_camera_prediction(img_prediction,
280
+ img_groundtruth,
281
+ args.threshold)
282
+
283
+ confusion2, per_class_conf2, reproj_errors2 = evaluate_camera_prediction(img_prediction,
284
+ mirror_labels(img_groundtruth),
285
+ args.threshold)
286
+
287
+ accuracy1, accuracy2 = 0., 0.
288
+ if confusion1.sum() > 0:
289
+ accuracy1 = confusion1[0, 0] / confusion1.sum()
290
+
291
+ if confusion2.sum() > 0:
292
+ accuracy2 = confusion2[0, 0] / confusion2.sum()
293
+
294
+ if accuracy1 > accuracy2:
295
+ accuracy = accuracy1
296
+ confusion = confusion1
297
+ per_class_conf = per_class_conf1
298
+ reproj_errors = reproj_errors1
299
+ else:
300
+ accuracy = accuracy2
301
+ confusion = confusion2
302
+ per_class_conf = per_class_conf2
303
+ reproj_errors = reproj_errors2
304
+
305
+ accuracies.append(accuracy)
306
+ if confusion[0, :].sum() > 0:
307
+ precision = confusion[0, 0] / (confusion[0, :].sum())
308
+ precisions.append(precision)
309
+ if (confusion[0, 0] + confusion[1, 0]) > 0:
310
+ recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
311
+ recalls.append(recall)
312
+
313
+ for line_class, errors in reproj_errors.items():
314
+ if line_class in dict_errors.keys():
315
+ dict_errors[line_class].extend(errors)
316
+ else:
317
+ dict_errors[line_class] = errors
318
+
319
+ for line_class, confusion_mat in per_class_conf.items():
320
+ if line_class in per_class_confusion_dict.keys():
321
+ per_class_confusion_dict[line_class] += confusion_mat
322
+ else:
323
+ per_class_confusion_dict[line_class] = confusion_mat
324
+
325
+ completeness_score = (total_frames - missed) / total_frames
326
+ mAccuracy = np.mean(accuracies)
327
+
328
+ final_score = completeness_score * mAccuracy
329
+ print(f" On SoccerNet {args.split} set, final score of : {final_score}")
330
+ print(f" On SoccerNet {args.split} set, completeness rate of : {completeness_score}")
331
+
332
+ mRecall = np.mean(recalls)
333
+ sRecall = np.std(recalls)
334
+ medianRecall = np.median(recalls)
335
+ print(
336
+ f" On SoccerNet {args.split} set, recall mean value : {mRecall * 100:2.2f}% with standard deviation of {sRecall * 100:2.2f}% and median of {medianRecall * 100:2.2f}%")
337
+
338
+ mPrecision = np.mean(precisions)
339
+ sPrecision = np.std(precisions)
340
+ medianPrecision = np.median(precisions)
341
+ print(
342
+ f" On SoccerNet {args.split} set, precision mean value : {mPrecision * 100:2.2f}% with standard deviation of {sPrecision * 100:2.2f}% and median of {medianPrecision * 100:2.2f}%")
343
+
344
+ sAccuracy = np.std(accuracies)
345
+ medianAccuracy = np.median(accuracies)
346
+ print(
347
+ f" On SoccerNet {args.split} set, accuracy mean value : {mAccuracy * 100:2.2f}% with standard deviation of {sAccuracy * 100:2.2f}% and median of {medianAccuracy * 100:2.2f}%")
348
+
349
+ print()
350
+
351
+ for line_class, confusion_mat in per_class_confusion_dict.items():
352
+ class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
353
+ class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
354
+ class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
355
+ print(
356
+ f"For class {line_class}, accuracy of {class_accuracy * 100:2.2f}%, precision of {class_precision * 100:2.2f}% and recall of {class_recall * 100:2.2f}%")
357
+
358
+ for k, v in dict_errors.items():
359
+ fig, ax1 = plt.subplots(figsize=(11, 8))
360
+ ax1.hist(v, bins=30, range=(0, 60))
361
+ ax1.set_title(k)
362
+ ax1.set_xlabel("Errors in pixel")
363
+ os.makedirs(f"./results/", exist_ok=True)
364
+ plt.savefig(f"./results/{k}_reprojection_error.png")
365
+ plt.close(fig)
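A single-frame sketch of the camera evaluation above, mirroring the __main__ block (file names are illustrative; resolution and threshold match the script defaults):

import json

from evaluate_camera import get_polylines, evaluate_camera_prediction
from evaluate_extremities import scale_points, mirror_labels

with open("00001.json") as f:           # groundtruth line annotations (normalized coordinates)
    gt = json.load(f)
with open("camera_00001.json") as f:    # predicted camera parameters
    pred_cam = json.load(f)

gt_lines = scale_points(gt, 960, 540)
proj_lines = get_polylines(pred_cam, 960, 540, sampling_factor=0.9)

# Evaluate against both labelings and keep the better one, since the pitch sides are ambiguous
conf1, _, _ = evaluate_camera_prediction(proj_lines, gt_lines, 5)
conf2, _, _ = evaluate_camera_prediction(proj_lines, mirror_labels(gt_lines), 5)
acc = max(conf1[0, 0] / conf1.sum() if conf1.sum() else 0.,
          conf2[0, 0] / conf2.sum() if conf2.sum() else 0.)
print(f"accuracy@5px: {acc:.3f}")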
sn_calibration/src/evaluate_extremities.py ADDED
@@ -0,0 +1,274 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ import os
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+ sys.path.append("sn_calibration")
11
+ sys.path.append("sn_calibration/src")
12
+
13
+ from soccerpitch import SoccerPitch
14
+
15
+
16
+ def distance(point1, point2):
17
+ """
18
+ Computes euclidian distance between 2D points
19
+ :param point1
20
+ :param point2
21
+ :return: euclidian distance between point1 and point2
22
+ """
23
+ diff = np.array([point1['x'], point1['y']]) - np.array([point2['x'], point2['y']])
24
+ sq_dist = np.square(diff)
25
+ return np.sqrt(sq_dist.sum())
26
+
27
+
28
+ def mirror_labels(lines_dict):
29
+ """
30
+ Replace each line class key of the dictionary with its opposite element according to a central projection by the
31
+ soccer pitch center
32
+ :param lines_dict: dictionary whose keys will be mirrored
33
+ :return: Dictionary with mirrored keys and same values
34
+ """
35
+ mirrored_dict = dict()
36
+ for line_class, value in lines_dict.items():
37
+ mirrored_dict[SoccerPitch.symetric_classes[line_class]] = value
38
+ return mirrored_dict
39
+
40
+
41
+ def evaluate_detection_prediction(detected_lines, groundtruth_lines, threshold=2.):
42
+ """
43
+ Evaluates the prediction of extremities. The extremities associated to a class are unordered. The extremities of the
44
+ "Circle central" element is not well-defined for this task, thus this class is ignored.
45
+ Computes confusion matrices for a level of precision specified by the threshold.
46
+ A groundtruth extremity point is correctly classified if it lies at less than threshold pixels from the
47
+ corresponding extremity point of the prediction of the same class.
48
+ Computes also the euclidian distance between each predicted extremity and its closest groundtruth extremity, when
49
+ both the groundtruth and the prediction contain the element class.
50
+
51
+ :param detected_lines: dictionary of detected lines classes as keys and associated predicted extremities as values
52
+ :param groundtruth_lines: dictionary of annotated lines classes as keys and associated annotated points as values
53
+ :param threshold: distance in pixels that distinguishes good matches from bad ones
54
+ :return: confusion matrix, per class confusion matrix & per class localization errors
55
+ """
56
+ confusion_mat = np.zeros((2, 2), dtype=np.float32)
57
+ per_class_confusion = {}
58
+ errors_dict = {}
59
+ detected_classes = set(detected_lines.keys())
60
+ groundtruth_classes = set(groundtruth_lines.keys())
61
+
62
+ if "Circle central" in groundtruth_classes:
63
+ groundtruth_classes.remove("Circle central")
64
+ if "Circle central" in detected_classes:
65
+ detected_classes.remove("Circle central")
66
+
67
+ false_positives_classes = detected_classes - groundtruth_classes
68
+ for false_positive_class in false_positives_classes:
69
+ false_positives = len(detected_lines[false_positive_class])
70
+ confusion_mat[0, 1] += false_positives
71
+ per_class_confusion[false_positive_class] = np.array([[0., false_positives], [0., 0.]])
72
+
73
+ false_negatives_classes = groundtruth_classes - detected_classes
74
+ for false_negatives_class in false_negatives_classes:
75
+ false_negatives = len(groundtruth_lines[false_negatives_class])
76
+ confusion_mat[1, 0] += false_negatives
77
+ per_class_confusion[false_negatives_class] = np.array([[0., 0.], [false_negatives, 0.]])
78
+
79
+ common_classes = detected_classes - false_positives_classes
80
+
81
+ for detected_class in common_classes:
82
+
83
+ detected_points = detected_lines[detected_class]
84
+
85
+ groundtruth_points = groundtruth_lines[detected_class]
86
+
87
+ groundtruth_extremities = [groundtruth_points[0], groundtruth_points[-1]]
88
+ predicted_extremities = [detected_points[0], detected_points[-1]]
89
+ per_class_confusion[detected_class] = np.zeros((2, 2))
90
+
91
+ dist1 = distance(groundtruth_extremities[0], predicted_extremities[0])
92
+ dist1rev = distance(groundtruth_extremities[1], predicted_extremities[0])
93
+
94
+ dist2 = distance(groundtruth_extremities[1], predicted_extremities[1])
95
+ dist2rev = distance(groundtruth_extremities[0], predicted_extremities[1])
96
+ if dist1rev <= dist1 and dist2rev <= dist2:
97
+ # reverse order
98
+ dist1 = dist1rev
99
+ dist2 = dist2rev
100
+
101
+ errors_dict[detected_class] = [dist1, dist2]
102
+
103
+ if dist1 < threshold:
104
+ confusion_mat[0, 0] += 1
105
+ per_class_confusion[detected_class][0, 0] += 1
106
+ else:
107
+ # treat too far detections as false positives
108
+ confusion_mat[0, 1] += 1
109
+ per_class_confusion[detected_class][0, 1] += 1
110
+
111
+ if dist2 < threshold:
112
+ confusion_mat[0, 0] += 1
113
+ per_class_confusion[detected_class][0, 0] += 1
114
+
115
+ else:
116
+ # treat too far detections as false positives
117
+ confusion_mat[0, 1] += 1
118
+ per_class_confusion[detected_class][0, 1] += 1
119
+
120
+ return confusion_mat, per_class_confusion, errors_dict
121
+
122
+
123
+ def scale_points(points_dict, s_width, s_height):
124
+ """
125
+ Scale points by s_width and s_height factors
126
+ :param points_dict: dictionary of annotations/predictions with normalized point values
127
+ :param s_width: width scaling factor
128
+ :param s_height: height scaling factor
129
+ :return: dictionary with scaled points
130
+ """
131
+ line_dict = {}
132
+ for line_class, points in points_dict.items():
133
+ scaled_points = []
134
+ for point in points:
135
+ new_point = {'x': point['x'] * (s_width-1), 'y': point['y'] * (s_height-1)}
136
+ scaled_points.append(new_point)
137
+ if len(scaled_points):
138
+ line_dict[line_class] = scaled_points
139
+ return line_dict
140
+
141
+
142
+ if __name__ == "__main__":
143
+
144
+ parser = argparse.ArgumentParser(description='Test')
145
+
146
+ parser.add_argument('-s', '--soccernet', default="./annotations", type=str,
147
+ help='Path to the SoccerNet-V3 dataset folder')
148
+ parser.add_argument('-p', '--prediction', default="./results_bis",
149
+ required=False, type=str,
150
+ help="Path to the prediction folder")
151
+ parser.add_argument('-t', '--threshold', default=10, required=False, type=int,
152
+ help="Accuracy threshold in pixels")
153
+ parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
154
+ parser.add_argument('--resolution_width', required=False, type=int, default=960,
155
+ help='width resolution of the images')
156
+ parser.add_argument('--resolution_height', required=False, type=int, default=540,
157
+ help='height resolution of the images')
158
+ args = parser.parse_args()
159
+
160
+ accuracies = []
161
+ precisions = []
162
+ recalls = []
163
+ dict_errors = {}
164
+ per_class_confusion_dict = {}
165
+
166
+ dataset_dir = os.path.join(args.soccernet, args.split)
167
+ if not os.path.exists(dataset_dir):
168
+ print("Invalid dataset path !")
169
+ exit(-1)
170
+
171
+ annotation_files = [f for f in os.listdir(dataset_dir) if ".json" in f]
172
+
173
+ with tqdm(enumerate(annotation_files), total=len(annotation_files), ncols=160) as t:
174
+ for i, annotation_file in t:
175
+ frame_index = annotation_file.split(".")[0]
176
+ annotation_file = os.path.join(args.soccernet, args.split, annotation_file)
177
+ prediction_file = os.path.join(args.prediction, args.split, f"extremities_{frame_index}.json")
178
+
179
+ if not os.path.exists(prediction_file):
180
+ accuracies.append(0.)
181
+ precisions.append(0.)
182
+ recalls.append(0.)
183
+ continue
184
+
185
+ with open(annotation_file, 'r') as f:
186
+ line_annotations = json.load(f)
187
+
188
+ with open(prediction_file, 'r') as f:
189
+ predictions = json.load(f)
+
+ predictions = scale_points(predictions, args.resolution_width, args.resolution_height)
+ line_annotations = scale_points(line_annotations, args.resolution_width, args.resolution_height)
+
+ img_prediction = predictions
+ img_groundtruth = line_annotations
+ confusion1, per_class_conf1, reproj_errors1 = evaluate_detection_prediction(img_prediction,
+ img_groundtruth,
+ args.threshold)
+ confusion2, per_class_conf2, reproj_errors2 = evaluate_detection_prediction(img_prediction,
+ mirror_labels(
+ img_groundtruth),
+ args.threshold)
+
+ accuracy1, accuracy2 = 0., 0.
+ if confusion1.sum() > 0:
+ accuracy1 = confusion1[0, 0] / confusion1.sum()
+
+ if confusion2.sum() > 0:
+ accuracy2 = confusion2[0, 0] / confusion2.sum()
+
+ if accuracy1 > accuracy2:
+ accuracy = accuracy1
+ confusion = confusion1
+ per_class_conf = per_class_conf1
+ reproj_errors = reproj_errors1
+ else:
+ accuracy = accuracy2
+ confusion = confusion2
+ per_class_conf = per_class_conf2
+ reproj_errors = reproj_errors2
+
+ accuracies.append(accuracy)
+ if confusion[0, :].sum() > 0:
+ precision = confusion[0, 0] / (confusion[0, :].sum())
+ precisions.append(precision)
+ if (confusion[0, 0] + confusion[1, 0]) > 0:
+ recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
+ recalls.append(recall)
+
+ for line_class, errors in reproj_errors.items():
+ if line_class in dict_errors.keys():
+ dict_errors[line_class].extend(errors)
+ else:
+ dict_errors[line_class] = errors
+
+ for line_class, confusion_mat in per_class_conf.items():
+ if line_class in per_class_confusion_dict.keys():
+ per_class_confusion_dict[line_class] += confusion_mat
+ else:
+ per_class_confusion_dict[line_class] = confusion_mat
+
+ mRecall = np.mean(recalls)
+ sRecall = np.std(recalls)
+ medianRecall = np.median(recalls)
+ print(
+ f" On SoccerNet {args.split} set, recall mean value : {mRecall * 100:2.2f}% with standard deviation of {sRecall * 100:2.2f}% and median of {medianRecall * 100:2.2f}%")
+
+ mPrecision = np.mean(precisions)
+ sPrecision = np.std(precisions)
+ medianPrecision = np.median(precisions)
+ print(
+ f" On SoccerNet {args.split} set, precision mean value : {mPrecision * 100:2.2f}% with standard deviation of {sPrecision * 100:2.2f}% and median of {medianPrecision * 100:2.2f}%")
+
+ mAccuracy = np.mean(accuracies)
+ sAccuracy = np.std(accuracies)
+ medianAccuracy = np.median(accuracies)
+ print(
+ f" On SoccerNet {args.split} set, accuracy mean value : {mAccuracy * 100:2.2f}% with standard deviation of {sAccuracy * 100:2.2f}% and median of {medianAccuracy * 100:2.2f}%")
+
+ for line_class, confusion_mat in per_class_confusion_dict.items():
+ class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
+ class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
+ class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
+ print(
+ f"For class {line_class}, accuracy of {class_accuracy * 100:2.2f}%, precision of {class_precision * 100:2.2f}% and recall of {class_recall * 100:2.2f}%")
+
+ for k, v in dict_errors.items():
+ fig, ax1 = plt.subplots(figsize=(11, 8))
+ ax1.hist(v, bins=30, range=(0, 60))
+ ax1.set_title(k)
+ ax1.set_xlabel("Errors in pixel")
+ os.makedirs(f"./results/", exist_ok=True)
+ plt.savefig(f"./results/{k}_detection_error.png")
+ plt.close(fig)
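
Editor's note: a minimal sketch of the 2x2 confusion-matrix convention that the formulas above appear to assume (confusion[0, 0] as true positives, confusion[0, 1] as false positives, confusion[1, 0] as false negatives). The numbers below are illustrative only, not values from the repository.

    import numpy as np

    # Hypothetical per-image confusion matrix: [[TP, FP], [FN, unused]]
    confusion = np.array([[80, 10],
                          [20, 0]])

    accuracy = confusion[0, 0] / confusion.sum()                    # 80 / 110 ~ 0.727
    precision = confusion[0, 0] / confusion[0, :].sum()             # 80 / 90  ~ 0.889
    recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])  # 80 / 100 = 0.800
    print(accuracy, precision, recall)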
sn_calibration/src/soccerpitch.py ADDED
@@ -0,0 +1,528 @@
+ import numpy as np
+
+
+ class SoccerPitch:
+ """Static class variables that are specified by the rules of the game """
+ GOAL_LINE_TO_PENALTY_MARK = 11.0
+ PENALTY_AREA_WIDTH = 40.32
+ PENALTY_AREA_LENGTH = 16.5
+ GOAL_AREA_WIDTH = 18.32
+ GOAL_AREA_LENGTH = 5.5
+ CENTER_CIRCLE_RADIUS = 9.15
+ GOAL_HEIGHT = 2.44
+ GOAL_LENGTH = 7.32
+
+ lines_classes = [
+ 'Big rect. left bottom',
+ 'Big rect. left main',
+ 'Big rect. left top',
+ 'Big rect. right bottom',
+ 'Big rect. right main',
+ 'Big rect. right top',
+ 'Circle central',
+ 'Circle left',
+ 'Circle right',
+ 'Goal left crossbar',
+ 'Goal left post left ',
+ 'Goal left post right',
+ 'Goal right crossbar',
+ 'Goal right post left',
+ 'Goal right post right',
+ 'Goal unknown',
+ 'Line unknown',
+ 'Middle line',
+ 'Side line bottom',
+ 'Side line left',
+ 'Side line right',
+ 'Side line top',
+ 'Small rect. left bottom',
+ 'Small rect. left main',
+ 'Small rect. left top',
+ 'Small rect. right bottom',
+ 'Small rect. right main',
+ 'Small rect. right top'
+ ]
+
+ symetric_classes = {
+ 'Side line top': 'Side line bottom',
+ 'Side line bottom': 'Side line top',
+ 'Side line left': 'Side line right',
+ 'Middle line': 'Middle line',
+ 'Side line right': 'Side line left',
+ 'Big rect. left top': 'Big rect. right bottom',
+ 'Big rect. left bottom': 'Big rect. right top',
+ 'Big rect. left main': 'Big rect. right main',
+ 'Big rect. right top': 'Big rect. left bottom',
+ 'Big rect. right bottom': 'Big rect. left top',
+ 'Big rect. right main': 'Big rect. left main',
+ 'Small rect. left top': 'Small rect. right bottom',
+ 'Small rect. left bottom': 'Small rect. right top',
+ 'Small rect. left main': 'Small rect. right main',
+ 'Small rect. right top': 'Small rect. left bottom',
+ 'Small rect. right bottom': 'Small rect. left top',
+ 'Small rect. right main': 'Small rect. left main',
+ 'Circle left': 'Circle right',
+ 'Circle central': 'Circle central',
+ 'Circle right': 'Circle left',
+ 'Goal left crossbar': 'Goal right crossbar',
+ 'Goal left post left ': 'Goal right post left',
+ 'Goal left post right': 'Goal right post right',
+ 'Goal right crossbar': 'Goal left crossbar',
+ 'Goal right post left': 'Goal left post left ',
+ 'Goal right post right': 'Goal left post right',
+ 'Goal unknown': 'Goal unknown',
+ 'Line unknown': 'Line unknown'
+ }
+
+ # RGB values
+ palette = {
+ 'Big rect. left bottom': (127, 0, 0),
+ 'Big rect. left main': (102, 102, 102),
+ 'Big rect. left top': (0, 0, 127),
+ 'Big rect. right bottom': (86, 32, 39),
+ 'Big rect. right main': (48, 77, 0),
+ 'Big rect. right top': (14, 97, 100),
+ 'Circle central': (0, 0, 255),
+ 'Circle left': (255, 127, 0),
+ 'Circle right': (0, 255, 255),
+ 'Goal left crossbar': (255, 255, 200),
+ 'Goal left post left ': (165, 255, 0),
+ 'Goal left post right': (155, 119, 45),
+ 'Goal right crossbar': (86, 32, 139),
+ 'Goal right post left': (196, 120, 153),
+ 'Goal right post right': (166, 36, 52),
+ 'Goal unknown': (0, 0, 0),
+ 'Line unknown': (0, 0, 0),
+ 'Middle line': (255, 255, 0),
+ 'Side line bottom': (255, 0, 255),
+ 'Side line left': (0, 255, 150),
+ 'Side line right': (0, 230, 0),
+ 'Side line top': (230, 0, 0),
+ 'Small rect. left bottom': (0, 150, 255),
+ 'Small rect. left main': (254, 173, 225),
+ 'Small rect. left top': (87, 72, 39),
+ 'Small rect. right bottom': (122, 0, 255),
+ 'Small rect. right main': (255, 255, 255),
+ 'Small rect. right top': (153, 23, 153)
+ }
+
+ def __init__(self, pitch_length=105., pitch_width=68.):
+ """
+ Initialize 3D coordinates of all elements of the soccer pitch.
+ :param pitch_length: According to FIFA rules, the pitch length must belong to [90, 120] meters
+ :param pitch_width: According to FIFA rules, the pitch width must belong to [45, 90] meters
+ """
+ self.PITCH_LENGTH = pitch_length
+ self.PITCH_WIDTH = pitch_width
+
+ self.center_mark = np.array([0, 0, 0], dtype='float')
+ self.halfway_and_bottom_touch_line_mark = np.array([0, pitch_width / 2.0, 0], dtype='float')
+ self.halfway_and_top_touch_line_mark = np.array([0, -pitch_width / 2.0, 0], dtype='float')
+ self.halfway_line_and_center_circle_top_mark = np.array([0, -SoccerPitch.CENTER_CIRCLE_RADIUS, 0],
+ dtype='float')
+ self.halfway_line_and_center_circle_bottom_mark = np.array([0, SoccerPitch.CENTER_CIRCLE_RADIUS, 0],
+ dtype='float')
+ self.bottom_right_corner = np.array([pitch_length / 2.0, pitch_width / 2.0, 0], dtype='float')
+ self.bottom_left_corner = np.array([-pitch_length / 2.0, pitch_width / 2.0, 0], dtype='float')
+ self.top_right_corner = np.array([pitch_length / 2.0, -pitch_width / 2.0, 0], dtype='float')
+ self.top_left_corner = np.array([-pitch_length / 2.0, -pitch_width / 2.0, 0], dtype='float')
+
+ self.left_goal_bottom_left_post = np.array([-pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., 0.],
+ dtype='float')
+ self.left_goal_top_left_post = np.array(
+ [-pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')
+ self.left_goal_bottom_right_post = np.array([-pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., 0.],
+ dtype='float')
+ self.left_goal_top_right_post = np.array(
+ [-pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')
+
+ self.right_goal_bottom_left_post = np.array([pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., 0.],
+ dtype='float')
+ self.right_goal_top_left_post = np.array(
+ [pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')
+ self.right_goal_bottom_right_post = np.array([pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., 0.],
+ dtype='float')
+ self.right_goal_top_right_post = np.array(
+ [pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')
+
+ self.left_penalty_mark = np.array([-pitch_length / 2.0 + SoccerPitch.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
+ dtype='float')
+ self.right_penalty_mark = np.array([pitch_length / 2.0 - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
+ dtype='float')
+
+ self.left_penalty_area_top_right_corner = np.array(
+ [-pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.left_penalty_area_top_left_corner = np.array(
+ [-pitch_length / 2.0, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.left_penalty_area_bottom_right_corner = np.array(
+ [-pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.left_penalty_area_bottom_left_corner = np.array(
+ [-pitch_length / 2.0, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_penalty_area_top_right_corner = np.array(
+ [pitch_length / 2.0, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_penalty_area_top_left_corner = np.array(
+ [pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_penalty_area_bottom_right_corner = np.array(
+ [pitch_length / 2.0, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_penalty_area_bottom_left_corner = np.array(
+ [pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+
+ self.left_goal_area_top_right_corner = np.array(
+ [-pitch_length / 2.0 + SoccerPitch.GOAL_AREA_LENGTH, -SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
+ self.left_goal_area_top_left_corner = np.array([-pitch_length / 2.0, - SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.left_goal_area_bottom_right_corner = np.array(
+ [-pitch_length / 2.0 + SoccerPitch.GOAL_AREA_LENGTH, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
+ self.left_goal_area_bottom_left_corner = np.array([-pitch_length / 2.0, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_goal_area_top_right_corner = np.array([pitch_length / 2.0, -SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_goal_area_top_left_corner = np.array(
+ [pitch_length / 2.0 - SoccerPitch.GOAL_AREA_LENGTH, -SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
+ self.right_goal_area_bottom_right_corner = np.array([pitch_length / 2.0, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
+ dtype='float')
+ self.right_goal_area_bottom_left_corner = np.array(
+ [pitch_length / 2.0 - SoccerPitch.GOAL_AREA_LENGTH, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
+
+ x = -pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH
+ dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK
+ y = -np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx)
+ self.top_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
+
+ x = pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH
+ dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK
+ y = -np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx)
+ self.top_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
+
+ x = -pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH
+ dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK
+ y = np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx)
+ self.bottom_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
+
+ x = pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH
+ dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK
+ y = np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx)
+ self.bottom_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
+
+ # self.set_elevations(elevation)
+
+ self.point_dict = {}
+ self.point_dict["CENTER_MARK"] = self.center_mark
+ self.point_dict["L_PENALTY_MARK"] = self.left_penalty_mark
+ self.point_dict["R_PENALTY_MARK"] = self.right_penalty_mark
+ self.point_dict["TL_PITCH_CORNER"] = self.top_left_corner
+ self.point_dict["BL_PITCH_CORNER"] = self.bottom_left_corner
+ self.point_dict["TR_PITCH_CORNER"] = self.top_right_corner
+ self.point_dict["BR_PITCH_CORNER"] = self.bottom_right_corner
+ self.point_dict["L_PENALTY_AREA_TL_CORNER"] = self.left_penalty_area_top_left_corner
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"] = self.left_penalty_area_top_right_corner
+ self.point_dict["L_PENALTY_AREA_BL_CORNER"] = self.left_penalty_area_bottom_left_corner
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"] = self.left_penalty_area_bottom_right_corner
+
+ self.point_dict["R_PENALTY_AREA_TL_CORNER"] = self.right_penalty_area_top_left_corner
+ self.point_dict["R_PENALTY_AREA_TR_CORNER"] = self.right_penalty_area_top_right_corner
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"] = self.right_penalty_area_bottom_left_corner
+ self.point_dict["R_PENALTY_AREA_BR_CORNER"] = self.right_penalty_area_bottom_right_corner
+
+ self.point_dict["L_GOAL_AREA_TL_CORNER"] = self.left_goal_area_top_left_corner
+ self.point_dict["L_GOAL_AREA_TR_CORNER"] = self.left_goal_area_top_right_corner
+ self.point_dict["L_GOAL_AREA_BL_CORNER"] = self.left_goal_area_bottom_left_corner
+ self.point_dict["L_GOAL_AREA_BR_CORNER"] = self.left_goal_area_bottom_right_corner
+
+ self.point_dict["R_GOAL_AREA_TL_CORNER"] = self.right_goal_area_top_left_corner
+ self.point_dict["R_GOAL_AREA_TR_CORNER"] = self.right_goal_area_top_right_corner
+ self.point_dict["R_GOAL_AREA_BL_CORNER"] = self.right_goal_area_bottom_left_corner
+ self.point_dict["R_GOAL_AREA_BR_CORNER"] = self.right_goal_area_bottom_right_corner
+
+ self.point_dict["L_GOAL_TL_POST"] = self.left_goal_top_left_post
+ self.point_dict["L_GOAL_TR_POST"] = self.left_goal_top_right_post
+ self.point_dict["L_GOAL_BL_POST"] = self.left_goal_bottom_left_post
+ self.point_dict["L_GOAL_BR_POST"] = self.left_goal_bottom_right_post
+
+ self.point_dict["R_GOAL_TL_POST"] = self.right_goal_top_left_post
+ self.point_dict["R_GOAL_TR_POST"] = self.right_goal_top_right_post
+ self.point_dict["R_GOAL_BL_POST"] = self.right_goal_bottom_left_post
+ self.point_dict["R_GOAL_BR_POST"] = self.right_goal_bottom_right_post
+
+ self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"] = self.halfway_and_top_touch_line_mark
+ self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"] = self.halfway_and_bottom_touch_line_mark
+ self.point_dict["T_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"] = self.halfway_line_and_center_circle_top_mark
+ self.point_dict[
+ "B_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"] = self.halfway_line_and_center_circle_bottom_mark
+ self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.top_left_16M_penalty_arc_mark
+ self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.bottom_left_16M_penalty_arc_mark
+ self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.top_right_16M_penalty_arc_mark
+ self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.bottom_right_16M_penalty_arc_mark
+
+ self.line_extremities = dict()
+ self.line_extremities["Big rect. left bottom"] = (self.point_dict["L_PENALTY_AREA_BL_CORNER"],
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"])
+ self.line_extremities["Big rect. left top"] = (self.point_dict["L_PENALTY_AREA_TL_CORNER"],
+ self.point_dict["L_PENALTY_AREA_TR_CORNER"])
+ self.line_extremities["Big rect. left main"] = (self.point_dict["L_PENALTY_AREA_TR_CORNER"],
+ self.point_dict["L_PENALTY_AREA_BR_CORNER"])
+ self.line_extremities["Big rect. right bottom"] = (self.point_dict["R_PENALTY_AREA_BL_CORNER"],
+ self.point_dict["R_PENALTY_AREA_BR_CORNER"])
+ self.line_extremities["Big rect. right top"] = (self.point_dict["R_PENALTY_AREA_TL_CORNER"],
+ self.point_dict["R_PENALTY_AREA_TR_CORNER"])
+ self.line_extremities["Big rect. right main"] = (self.point_dict["R_PENALTY_AREA_TL_CORNER"],
+ self.point_dict["R_PENALTY_AREA_BL_CORNER"])
+
+ self.line_extremities["Small rect. left bottom"] = (self.point_dict["L_GOAL_AREA_BL_CORNER"],
+ self.point_dict["L_GOAL_AREA_BR_CORNER"])
+ self.line_extremities["Small rect. left top"] = (self.point_dict["L_GOAL_AREA_TL_CORNER"],
+ self.point_dict["L_GOAL_AREA_TR_CORNER"])
+ self.line_extremities["Small rect. left main"] = (self.point_dict["L_GOAL_AREA_TR_CORNER"],
+ self.point_dict["L_GOAL_AREA_BR_CORNER"])
+ self.line_extremities["Small rect. right bottom"] = (self.point_dict["R_GOAL_AREA_BL_CORNER"],
+ self.point_dict["R_GOAL_AREA_BR_CORNER"])
+ self.line_extremities["Small rect. right top"] = (self.point_dict["R_GOAL_AREA_TL_CORNER"],
+ self.point_dict["R_GOAL_AREA_TR_CORNER"])
+ self.line_extremities["Small rect. right main"] = (self.point_dict["R_GOAL_AREA_TL_CORNER"],
+ self.point_dict["R_GOAL_AREA_BL_CORNER"])
+
+ self.line_extremities["Side line top"] = (self.point_dict["TL_PITCH_CORNER"],
+ self.point_dict["TR_PITCH_CORNER"])
+ self.line_extremities["Side line bottom"] = (self.point_dict["BL_PITCH_CORNER"],
+ self.point_dict["BR_PITCH_CORNER"])
+ self.line_extremities["Side line left"] = (self.point_dict["TL_PITCH_CORNER"],
+ self.point_dict["BL_PITCH_CORNER"])
+ self.line_extremities["Side line right"] = (self.point_dict["TR_PITCH_CORNER"],
+ self.point_dict["BR_PITCH_CORNER"])
+ self.line_extremities["Middle line"] = (self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
+ self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"])
+
+ self.line_extremities["Goal left crossbar"] = (self.point_dict["L_GOAL_TR_POST"],
+ self.point_dict["L_GOAL_TL_POST"])
+ self.line_extremities["Goal left post left "] = (self.point_dict["L_GOAL_TL_POST"],
+ self.point_dict["L_GOAL_BL_POST"])
+ self.line_extremities["Goal left post right"] = (self.point_dict["L_GOAL_TR_POST"],
+ self.point_dict["L_GOAL_BR_POST"])
+
+ self.line_extremities["Goal right crossbar"] = (self.point_dict["R_GOAL_TL_POST"],
+ self.point_dict["R_GOAL_TR_POST"])
+ self.line_extremities["Goal right post left"] = (self.point_dict["R_GOAL_TL_POST"],
+ self.point_dict["R_GOAL_BL_POST"])
+ self.line_extremities["Goal right post right"] = (self.point_dict["R_GOAL_TR_POST"],
+ self.point_dict["R_GOAL_BR_POST"])
+ self.line_extremities["Circle right"] = (self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
+ self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"])
+ self.line_extremities["Circle left"] = (self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
+ self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"])
+
+ self.line_extremities_keys = dict()
+ self.line_extremities_keys["Big rect. left bottom"] = ("L_PENALTY_AREA_BL_CORNER",
+ "L_PENALTY_AREA_BR_CORNER")
+ self.line_extremities_keys["Big rect. left top"] = ("L_PENALTY_AREA_TL_CORNER",
+ "L_PENALTY_AREA_TR_CORNER")
+ self.line_extremities_keys["Big rect. left main"] = ("L_PENALTY_AREA_TR_CORNER",
+ "L_PENALTY_AREA_BR_CORNER")
+ self.line_extremities_keys["Big rect. right bottom"] = ("R_PENALTY_AREA_BL_CORNER",
+ "R_PENALTY_AREA_BR_CORNER")
+ self.line_extremities_keys["Big rect. right top"] = ("R_PENALTY_AREA_TL_CORNER",
+ "R_PENALTY_AREA_TR_CORNER")
+ self.line_extremities_keys["Big rect. right main"] = ("R_PENALTY_AREA_TL_CORNER",
+ "R_PENALTY_AREA_BL_CORNER")
+
+ self.line_extremities_keys["Small rect. left bottom"] = ("L_GOAL_AREA_BL_CORNER",
+ "L_GOAL_AREA_BR_CORNER")
+ self.line_extremities_keys["Small rect. left top"] = ("L_GOAL_AREA_TL_CORNER",
+ "L_GOAL_AREA_TR_CORNER")
+ self.line_extremities_keys["Small rect. left main"] = ("L_GOAL_AREA_TR_CORNER",
+ "L_GOAL_AREA_BR_CORNER")
+ self.line_extremities_keys["Small rect. right bottom"] = ("R_GOAL_AREA_BL_CORNER",
+ "R_GOAL_AREA_BR_CORNER")
+ self.line_extremities_keys["Small rect. right top"] = ("R_GOAL_AREA_TL_CORNER",
+ "R_GOAL_AREA_TR_CORNER")
+ self.line_extremities_keys["Small rect. right main"] = ("R_GOAL_AREA_TL_CORNER",
+ "R_GOAL_AREA_BL_CORNER")
+
+ self.line_extremities_keys["Side line top"] = ("TL_PITCH_CORNER",
+ "TR_PITCH_CORNER")
+ self.line_extremities_keys["Side line bottom"] = ("BL_PITCH_CORNER",
+ "BR_PITCH_CORNER")
+ self.line_extremities_keys["Side line left"] = ("TL_PITCH_CORNER",
+ "BL_PITCH_CORNER")
+ self.line_extremities_keys["Side line right"] = ("TR_PITCH_CORNER",
+ "BR_PITCH_CORNER")
+ self.line_extremities_keys["Middle line"] = ("T_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
+ "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION")
+
+ self.line_extremities_keys["Goal left crossbar"] = ("L_GOAL_TR_POST",
+ "L_GOAL_TL_POST")
+ self.line_extremities_keys["Goal left post left "] = ("L_GOAL_TL_POST",
+ "L_GOAL_BL_POST")
+ self.line_extremities_keys["Goal left post right"] = ("L_GOAL_TR_POST",
+ "L_GOAL_BR_POST")
+
+ self.line_extremities_keys["Goal right crossbar"] = ("R_GOAL_TL_POST",
+ "R_GOAL_TR_POST")
+ self.line_extremities_keys["Goal right post left"] = ("R_GOAL_TL_POST",
+ "R_GOAL_BL_POST")
+ self.line_extremities_keys["Goal right post right"] = ("R_GOAL_TR_POST",
+ "R_GOAL_BR_POST")
+ self.line_extremities_keys["Circle right"] = ("TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
+ "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION")
+ self.line_extremities_keys["Circle left"] = ("TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
+ "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION")
+
+ def points(self):
+ return [
+ self.center_mark,
+ self.halfway_and_bottom_touch_line_mark,
+ self.halfway_and_top_touch_line_mark,
+ self.halfway_line_and_center_circle_top_mark,
+ self.halfway_line_and_center_circle_bottom_mark,
+ self.bottom_right_corner,
+ self.bottom_left_corner,
+ self.top_right_corner,
+ self.top_left_corner,
+ self.left_penalty_mark,
+ self.right_penalty_mark,
+ self.left_penalty_area_top_right_corner,
+ self.left_penalty_area_top_left_corner,
+ self.left_penalty_area_bottom_right_corner,
+ self.left_penalty_area_bottom_left_corner,
+ self.right_penalty_area_top_right_corner,
+ self.right_penalty_area_top_left_corner,
+ self.right_penalty_area_bottom_right_corner,
+ self.right_penalty_area_bottom_left_corner,
+ self.left_goal_area_top_right_corner,
+ self.left_goal_area_top_left_corner,
+ self.left_goal_area_bottom_right_corner,
+ self.left_goal_area_bottom_left_corner,
+ self.right_goal_area_top_right_corner,
+ self.right_goal_area_top_left_corner,
+ self.right_goal_area_bottom_right_corner,
+ self.right_goal_area_bottom_left_corner,
+ self.top_left_16M_penalty_arc_mark,
+ self.top_right_16M_penalty_arc_mark,
+ self.bottom_left_16M_penalty_arc_mark,
+ self.bottom_right_16M_penalty_arc_mark,
+ self.left_goal_top_left_post,
+ self.left_goal_top_right_post,
+ self.left_goal_bottom_left_post,
+ self.left_goal_bottom_right_post,
+ self.right_goal_top_left_post,
+ self.right_goal_top_right_post,
+ self.right_goal_bottom_left_post,
+ self.right_goal_bottom_right_post
+ ]
+
+ def sample_field_points(self, dist=0.1, dist_circles=0.2):
+ """
+ Samples each pitch element every dist meters, returns a dictionary associating the class of the element with a list of points sampled along this element.
+ :param dist: the distance in meters between each point sampled
+ :param dist_circles: the distance in meters between each point sampled on circles
+ :return: a dictionary associating the class of the element with a list of points sampled along this element.
+ """
+ polylines = dict()
+ center = self.point_dict["CENTER_MARK"]
+ fromAngle = 0.
+ toAngle = 2 * np.pi
+
+ if toAngle < fromAngle:
+ toAngle += 2 * np.pi
+ x1 = center[0] + np.cos(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ y1 = center[1] + np.sin(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ z1 = 0.
+ point = np.array((x1, y1, z1))
+ polyline = [point]
+ length = SoccerPitch.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
+ nb_pts = int(length / dist_circles)
+ dangle = dist_circles / SoccerPitch.CENTER_CIRCLE_RADIUS
+ for i in range(1, nb_pts):
+ angle = fromAngle + i * dangle
+ x = center[0] + np.cos(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ y = center[1] + np.sin(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ z = 0
+ point = np.array((x, y, z))
+ polyline.append(point)
+ polylines["Circle central"] = polyline
+ for key, line in self.line_extremities.items():
+
+ if "Circle" in key:
+ if key == "Circle right":
+ top = self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+ bottom = self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+ center = self.point_dict["R_PENALTY_MARK"]
+ toAngle = np.arctan2(top[1] - center[1],
+ top[0] - center[0]) + 2 * np.pi
+ fromAngle = np.arctan2(bottom[1] - center[1],
+ bottom[0] - center[0]) + 2 * np.pi
+ elif key == "Circle left":
+ top = self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+ bottom = self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+ center = self.point_dict["L_PENALTY_MARK"]
+ fromAngle = np.arctan2(top[1] - center[1],
+ top[0] - center[0]) + 2 * np.pi
+ toAngle = np.arctan2(bottom[1] - center[1],
+ bottom[0] - center[0]) + 2 * np.pi
+ if toAngle < fromAngle:
+ toAngle += 2 * np.pi
+ x1 = center[0] + np.cos(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ y1 = center[1] + np.sin(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ z1 = 0.
+ xn = center[0] + np.cos(toAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ yn = center[1] + np.sin(toAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ zn = 0.
+ start = np.array((x1, y1, z1))
+ end = np.array((xn, yn, zn))
+ polyline = [start]
+ length = SoccerPitch.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
+ nb_pts = int(length / dist_circles)
+ dangle = dist_circles / SoccerPitch.CENTER_CIRCLE_RADIUS
+ for i in range(1, nb_pts + 1):
+ angle = fromAngle + i * dangle
+ x = center[0] + np.cos(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ y = center[1] + np.sin(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+ z = 0
+ point = np.array((x, y, z))
+ polyline.append(point)
+ polyline.append(end)
+ polylines[key] = polyline
+ else:
+ start = line[0]
+ end = line[1]
+
+ polyline = [start]
+
+ total_dist = np.sqrt(np.sum(np.square(start - end)))
+ nb_pts = int(total_dist / dist - 1)
+
+ v = end - start
+ v /= np.linalg.norm(v)
+ prev_pt = start
+ for i in range(nb_pts):
+ pt = prev_pt + dist * v
+ prev_pt = pt
+ polyline.append(pt)
+ polyline.append(end)
+ polylines[key] = polyline
+ return polylines
+
+ def get_2d_homogeneous_line(self, line_name):
+ """
+ For a line lying in the pitch ground plane, returns the coefficients of its 2D homogeneous line equation.
+ :param line_name: name of the line (one of the classes listed in lines_classes)
+ :return: an array containing the three coefficients of the line
+ """
+ # ensure line in football pitch plane
+ if line_name in self.line_extremities.keys() and \
+ "post" not in line_name and \
+ "crossbar" not in line_name and "Circle" not in line_name:
+ extremities = self.line_extremities[line_name]
+ p1 = np.array([extremities[0][0], extremities[0][1], 1], dtype="float")
+ p2 = np.array([extremities[1][0], extremities[1][1], 1], dtype="float")
+ line = np.cross(p1, p2)
+
+ return line
+ return None
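
Editor's note: a minimal, hypothetical usage sketch of the SoccerPitch class added above, assuming the module is importable as soccerpitch; it only illustrates the API (default 105 m x 68 m pitch, point sampling, ground-plane line equation) and is not code from the repository.

    from soccerpitch import SoccerPitch

    pitch = SoccerPitch()                            # default 105 m x 68 m pitch
    polylines = pitch.sample_field_points(dist=0.5)  # dict: line class -> list of sampled 3D points
    print(len(polylines["Side line top"]))           # number of points sampled along the top touch line

    # Homogeneous coefficients (a, b, c) of a ground-plane line, satisfying a*x + b*y + c = 0
    line = pitch.get_2d_homogeneous_line("Middle line")
    print(line)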