Upload 63 files
This view is limited to 50 files because it contains too many changes.
- .dockerignore +68 -0
- .gitignore +109 -0
- Dockerfile +10 -0
- LICENSE +339 -0
- PnLCalib.yml +235 -0
- README.md +134 -10
- api.py +162 -0
- config.yaml +33 -0
- config/field_config.py +164 -0
- config/hrnetv2_w48.yaml +35 -0
- config/hrnetv2_w48_l.yaml +35 -0
- data/__init__.py +1 -0
- data/line_data.py +115 -0
- get_camera_params.py +201 -0
- inference.py +286 -0
- model/cls_hrnet.py +479 -0
- model/cls_hrnet_l.py +478 -0
- model/dataloader.py +219 -0
- model/dataloader_l.py +202 -0
- model/losses.py +221 -0
- model/metrics.py +312 -0
- model/transforms.py +364 -0
- model/transformsWC.py +212 -0
- model/transformsWC_l.py +208 -0
- model/transforms_l.py +360 -0
- requirements.txt +44 -0
- run_api.py +13 -0
- scripts/eval_tswc.py +175 -0
- scripts/eval_wc14.py +154 -0
- scripts/inference_sn.py +186 -0
- scripts/inference_tswc.py +145 -0
- scripts/inference_wc14.py +148 -0
- scripts/run_pipeline_sn22.sh +23 -0
- scripts/run_pipeline_sn23.sh +23 -0
- scripts/run_pipeline_tswc.sh +20 -0
- scripts/run_pipeline_wc14.sh +20 -0
- scripts/run_pipeline_wc14_3D.sh +23 -0
- sn_calibration/ChallengeRules.md +44 -0
- sn_calibration/README.md +440 -0
- sn_calibration/requirements.txt +17 -0
- sn_calibration/resources/mean.npy +3 -0
- sn_calibration/resources/std.npy +3 -0
- sn_calibration/src/baseline_cameras.py +251 -0
- sn_calibration/src/camera.py +402 -0
- sn_calibration/src/dataloader.py +122 -0
- sn_calibration/src/detect_extremities.py +314 -0
- sn_calibration/src/evalai_camera.py +133 -0
- sn_calibration/src/evaluate_camera.py +365 -0
- sn_calibration/src/evaluate_extremities.py +274 -0
- sn_calibration/src/soccerpitch.py +528 -0
.dockerignore
ADDED
@@ -0,0 +1,68 @@
# Exclude virtual environments
.venv/
env/
venv/
ENV/
env.bak/
venv.bak/

# Python cache
__pycache__/
*.pyc
*.pyo
*.pyd
*.pyc
*.pyw
*.pyz

# OS / IDE files
.DS_Store
.vscode/
.idea/
*.swp
*.swo
*~

# Git
.git/
.gitignore

# Logs & temp
*.log
*.tmp
*.temp
.cache/
tmp/

# Data and intermediate results
input/
output/
*.zip
*.png
*.jpg
*.jpeg
*.npy
*.pkl
*.h5
*.pth
*.pt

# Project-specific files
data/private/
UnderstandScripts/
env_test/
requirements_copy.txt
requirements_clean.txt

# Vercel & Docker files
.vercel/
Dockerfile
docker-compose.yml
.dockerignore

# Videos
*.mp4

models/
.gitignore
ADDED
@@ -0,0 +1,109 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.venv/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Logs
*.log
logs/

# Temporary files
*.tmp
*.temp
.cache/
tmp/

# Data files (if sensitive)
data/private/
*.pkl
*.h5

# Model files (if too large)
*.pth
*.pt

# Environment variables
.env.local
.env.production

# Jupyter
.ipynb_checkpoints/

# pytest
.pytest_cache/

# Coverage
htmlcov/
.coverage
.coverage.*

# Docker (optional, depending on your needs)
# Dockerfile
# .dockerignore

# Vercel
.vercel

input/
output/
*.zip
*.png
*.jpg
*.npy

UnderstandScripts/
inference/
models/

env_test/

requirements_copy.txt

requirements_clean.txt
Dockerfile
ADDED
@@ -0,0 +1,10 @@
FROM python:3.10-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
LICENSE
ADDED
@@ -0,0 +1,339 @@
GNU GENERAL PUBLIC LICENSE
Version 2, June 1991

Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

Preamble

The licenses for most software are designed to take away your
freedom to share and change it. By contrast, the GNU General Public
License is intended to guarantee your freedom to share and change free
software--to make sure the software is free for all its users. This
General Public License applies to most of the Free Software
Foundation's software and to any other program whose authors commit to
using it. (Some other Free Software Foundation software is covered by
the GNU Lesser General Public License instead.) You can apply it to
your programs, too.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
this service if you wish), that you receive source code or can get it
if you want it, that you can change the software or use pieces of it
in new free programs; and that you know you can do these things.

To protect your rights, we need to make restrictions that forbid
anyone to deny you these rights or to ask you to surrender the rights.
These restrictions translate to certain responsibilities for you if you
distribute copies of the software, or if you modify it.

For example, if you distribute copies of such a program, whether
gratis or for a fee, you must give the recipients all the rights that
you have. You must make sure that they, too, receive or can get the
source code. And you must show them these terms so they know their
rights.

We protect your rights with two steps: (1) copyright the software, and
(2) offer you this license which gives you legal permission to copy,
distribute and/or modify the software.

Also, for each author's protection and ours, we want to make certain
that everyone understands that there is no warranty for this free
software. If the software is modified by someone else and passed on, we
want its recipients to know that what they have is not the original, so
that any problems introduced by others will not reflect on the original
authors' reputations.

Finally, any free program is threatened constantly by software
patents. We wish to avoid the danger that redistributors of a free
program will individually obtain patent licenses, in effect making the
program proprietary. To prevent this, we have made it clear that any
patent must be licensed for everyone's free use or not licensed at all.

The precise terms and conditions for copying, distribution and
modification follow.

GNU GENERAL PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION

0. This License applies to any program or other work which contains
a notice placed by the copyright holder saying it may be distributed
under the terms of this General Public License. The "Program", below,
refers to any such program or work, and a "work based on the Program"
means either the Program or any derivative work under copyright law:
that is to say, a work containing the Program or a portion of it,
either verbatim or with modifications and/or translated into another
language. (Hereinafter, translation is included without limitation in
the term "modification".) Each licensee is addressed as "you".

Activities other than copying, distribution and modification are not
covered by this License; they are outside its scope. The act of
running the Program is not restricted, and the output from the Program
is covered only if its contents constitute a work based on the
Program (independent of having been made by running the Program).
Whether that is true depends on what the Program does.

1. You may copy and distribute verbatim copies of the Program's
source code as you receive it, in any medium, provided that you
conspicuously and appropriately publish on each copy an appropriate
copyright notice and disclaimer of warranty; keep intact all the
notices that refer to this License and to the absence of any warranty;
and give any other recipients of the Program a copy of this License
along with the Program.

You may charge a fee for the physical act of transferring a copy, and
you may at your option offer warranty protection in exchange for a fee.

2. You may modify your copy or copies of the Program or any portion
of it, thus forming a work based on the Program, and copy and
distribute such modifications or work under the terms of Section 1
above, provided that you also meet all of these conditions:

a) You must cause the modified files to carry prominent notices
stating that you changed the files and the date of any change.

b) You must cause any work that you distribute or publish, that in
whole or in part contains or is derived from the Program or any
part thereof, to be licensed as a whole at no charge to all third
parties under the terms of this License.

c) If the modified program normally reads commands interactively
when run, you must cause it, when started running for such
interactive use in the most ordinary way, to print or display an
announcement including an appropriate copyright notice and a
notice that there is no warranty (or else, saying that you provide
a warranty) and that users may redistribute the program under
these conditions, and telling the user how to view a copy of this
License. (Exception: if the Program itself is interactive but
does not normally print such an announcement, your work based on
the Program is not required to print an announcement.)

These requirements apply to the modified work as a whole. If
identifiable sections of that work are not derived from the Program,
and can be reasonably considered independent and separate works in
themselves, then this License, and its terms, do not apply to those
sections when you distribute them as separate works. But when you
distribute the same sections as part of a whole which is a work based
on the Program, the distribution of the whole must be on the terms of
this License, whose permissions for other licensees extend to the
entire whole, and thus to each and every part regardless of who wrote it.

Thus, it is not the intent of this section to claim rights or contest
your rights to work written entirely by you; rather, the intent is to
exercise the right to control the distribution of derivative or
collective works based on the Program.

In addition, mere aggregation of another work not based on the Program
with the Program (or with a work based on the Program) on a volume of
a storage or distribution medium does not bring the other work under
the scope of this License.

3. You may copy and distribute the Program (or a work based on it,
under Section 2) in object code or executable form under the terms of
Sections 1 and 2 above provided that you also do one of the following:

a) Accompany it with the complete corresponding machine-readable
source code, which must be distributed under the terms of Sections
1 and 2 above on a medium customarily used for software interchange; or,

b) Accompany it with a written offer, valid for at least three
years, to give any third party, for a charge no more than your
cost of physically performing source distribution, a complete
machine-readable copy of the corresponding source code, to be
distributed under the terms of Sections 1 and 2 above on a medium
customarily used for software interchange; or,

c) Accompany it with the information you received as to the offer
to distribute corresponding source code. (This alternative is
allowed only for noncommercial distribution and only if you
received the program in object code or executable form with such
an offer, in accord with Subsection b above.)

The source code for a work means the preferred form of the work for
making modifications to it. For an executable work, complete source
code means all the source code for all modules it contains, plus any
associated interface definition files, plus the scripts used to
control compilation and installation of the executable. However, as a
special exception, the source code distributed need not include
anything that is normally distributed (in either source or binary
form) with the major components (compiler, kernel, and so on) of the
operating system on which the executable runs, unless that component
itself accompanies the executable.

If distribution of executable or object code is made by offering
access to copy from a designated place, then offering equivalent
access to copy the source code from the same place counts as
distribution of the source code, even though third parties are not
compelled to copy the source along with the object code.

4. You may not copy, modify, sublicense, or distribute the Program
except as expressly provided under this License. Any attempt
otherwise to copy, modify, sublicense or distribute the Program is
void, and will automatically terminate your rights under this License.
However, parties who have received copies, or rights, from you under
this License will not have their licenses terminated so long as such
parties remain in full compliance.

5. You are not required to accept this License, since you have not
signed it. However, nothing else grants you permission to modify or
distribute the Program or its derivative works. These actions are
prohibited by law if you do not accept this License. Therefore, by
modifying or distributing the Program (or any work based on the
Program), you indicate your acceptance of this License to do so, and
all its terms and conditions for copying, distributing or modifying
the Program or works based on it.

6. Each time you redistribute the Program (or any work based on the
Program), the recipient automatically receives a license from the
original licensor to copy, distribute or modify the Program subject to
these terms and conditions. You may not impose any further
restrictions on the recipients' exercise of the rights granted herein.
You are not responsible for enforcing compliance by third parties to
this License.

7. If, as a consequence of a court judgment or allegation of patent
infringement or for any other reason (not limited to patent issues),
conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot
distribute so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you
may not distribute the Program at all. For example, if a patent
license would not permit royalty-free redistribution of the Program by
all those who receive copies directly or indirectly through you, then
the only way you could satisfy both it and this License would be to
refrain entirely from distribution of the Program.

If any portion of this section is held invalid or unenforceable under
any particular circumstance, the balance of the section is intended to
apply and the section as a whole is intended to apply in other
circumstances.

It is not the purpose of this section to induce you to infringe any
patents or other property right claims or to contest validity of any
such claims; this section has the sole purpose of protecting the
integrity of the free software distribution system, which is
implemented by public license practices. Many people have made
generous contributions to the wide range of software distributed
through that system in reliance on consistent application of that
system; it is up to the author/donor to decide if he or she is willing
to distribute software through any other system and a licensee cannot
impose that choice.

This section is intended to make thoroughly clear what is believed to
be a consequence of the rest of this License.

8. If the distribution and/or use of the Program is restricted in
certain countries either by patents or by copyrighted interfaces, the
original copyright holder who places the Program under this License
may add an explicit geographical distribution limitation excluding
those countries, so that distribution is permitted only in or among
countries not thus excluded. In such case, this License incorporates
the limitation as if written in the body of this License.

9. The Free Software Foundation may publish revised and/or new versions
of the General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

Each version is given a distinguishing version number. If the Program
specifies a version number of this License which applies to it and "any
later version", you have the option of following the terms and conditions
either of that version or of any later version published by the Free
Software Foundation. If the Program does not specify a version number of
this License, you may choose any version ever published by the Free Software
Foundation.

10. If you wish to incorporate parts of the Program into other free
programs whose distribution conditions are different, write to the author
to ask for permission. For software which is copyrighted by the Free
Software Foundation, write to the Free Software Foundation; we sometimes
make exceptions for this. Our decision will be guided by the two goals
of preserving the free status of all derivatives of our free software and
of promoting the sharing and reuse of software generally.

NO WARRANTY

11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.

12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGES.

END OF TERMS AND CONDITIONS

How to Apply These Terms to Your New Programs

If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
convey the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

Also add information on how to contact you by electronic and paper mail.

If the program is interactive, make it output a short notice like this
when it starts in an interactive mode:

Gnomovision version 69, Copyright (C) year name of author
Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, the commands you use may
be called something other than `show w' and `show c'; they could even be
mouse-clicks or menu items--whatever suits your program.

You should also get your employer (if you work as a programmer) or your
school, if any, to sign a "copyright disclaimer" for the program, if
necessary. Here is a sample; alter the names:

Yoyodyne, Inc., hereby disclaims all copyright interest in the program
`Gnomovision' (which makes passes at compilers) written by James Hacker.

<signature of Ty Coon>, 1 April 1989
Ty Coon, President of Vice

This General Public License does not permit incorporating your program into
proprietary programs. If your program is a subroutine library, you may
consider it more useful to permit linking proprietary applications with the
library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License.
PnLCalib.yml
ADDED
@@ -0,0 +1,235 @@
name: PnLCalib
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - alsa-lib=1.2.8=h166bdaf_0
  - attr=2.5.1=h166bdaf_1
  - blas=2.116=mkl
  - blas-devel=3.9.0=16_linux64_mkl
  - brotli=1.1.0=hd590300_1
  - brotli-bin=1.1.0=hd590300_1
  - brotli-python=1.1.0=py39h3d6467e_1
  - bzip2=1.0.8=hd590300_5
  - ca-certificates=2024.6.2=hbcca054_0
  - cairo=1.16.0=ha61ee94_1014
  - certifi=2024.6.2=pyhd8ed1ab_0
  - cffi=1.16.0=py39h7a31438_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - contourpy=1.2.1=py39h7633fee_0
  - cuda-cudart=12.1.105=0
  - cuda-cupti=12.1.105=0
  - cuda-libraries=12.1.0=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-opencl=12.5.39=0
  - cuda-runtime=12.1.0=0
  - cuda-version=12.5=3
  - cycler=0.12.1=pyhd8ed1ab_0
  - dbus=1.13.6=h5008d03_3
  - expat=2.6.2=h59595ed_0
  - ffmpeg=4.3=hf484d3e_0
  - fftw=3.3.10=nompi_hf1063bd_110
  - filelock=3.15.4=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_2
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.53.0=py39hd3abc70_0
  - freetype=2.12.1=h267a509_2
  - geos=3.12.1=h59595ed_0
  - gettext=0.22.5=h59595ed_2
  - gettext-tools=0.22.5=h59595ed_2
  - glib=2.80.2=hf974151_0
  - glib-tools=2.80.2=hb6ce0ca_0
  - gmp=6.3.0=hac33072_2
  - gmpy2=2.1.5=py39h048c657_1
  - gnutls=3.6.13=h85f3911_1
  - graphite2=1.3.13=h59595ed_1003
  - gst-plugins-base=1.22.0=h4243ec0_2
  - gstreamer=1.22.0=h25f0c4b_2
  - gstreamer-orc=0.4.38=hd590300_0
  - h2=4.1.0=pyhd8ed1ab_0
  - harfbuzz=6.0.0=h8e241bc_0
  - hpack=4.0.0=pyh9f0ad1d_0
  - hyperframe=6.0.1=pyhd8ed1ab_0
  - icu=70.1=h27087fc_0
  - idna=3.7=pyhd8ed1ab_0
  - importlib-resources=6.4.0=pyhd8ed1ab_0
  - importlib_resources=6.4.0=pyhd8ed1ab_0
  - jack=1.9.22=h11f4161_0
  - jinja2=3.1.4=pyhd8ed1ab_0
  - jpeg=9e=h166bdaf_2
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.5=py39h7633fee_1
  - krb5=1.20.1=h81ceb04_0
  - lame=3.100=h166bdaf_1003
  - lcms2=2.15=hfd0df8a_0
  - ld_impl_linux-64=2.40=hf3520f5_7
  - lerc=4.0.0=h27087fc_0
  - libasprintf=0.22.5=h661eb56_2
  - libasprintf-devel=0.22.5=h661eb56_2
  - libblas=3.9.0=16_linux64_mkl
  - libbrotlicommon=1.1.0=hd590300_1
  - libbrotlidec=1.1.0=hd590300_1
  - libbrotlienc=1.1.0=hd590300_1
  - libcap=2.67=he9d0100_0
  - libcblas=3.9.0=16_linux64_mkl
  - libclang=15.0.7=default_h127d8a8_5
  - libclang13=15.0.7=default_h5d6823c_5
  - libcublas=12.1.0.26=0
  - libcufft=11.0.2.4=0
  - libcufile=1.10.0.4=0
  - libcups=2.3.3=h36d4200_3
  - libcurand=10.3.6.39=0
  - libcusolver=11.4.4.55=0
  - libcusparse=12.0.2.55=0
  - libdb=6.2.32=h9c3ff4c_0
  - libdeflate=1.17=h0b41bf4_0
  - libedit=3.1.20191231=he28a2e2_2
  - libevent=2.1.10=h28343ad_4
  - libexpat=2.6.2=h59595ed_0
  - libffi=3.4.2=h7f98852_5
  - libflac=1.4.3=h59595ed_0
  - libgcc-ng=14.1.0=h77fa898_0
  - libgcrypt=1.11.0=h4ab18f5_0
  - libgettextpo=0.22.5=h59595ed_2
  - libgettextpo-devel=0.22.5=h59595ed_2
  - libgfortran-ng=14.1.0=h69a702a_0
  - libgfortran5=14.1.0=hc5f4f2c_0
  - libglib=2.80.2=hf974151_0
  - libgomp=14.1.0=h77fa898_0
  - libgpg-error=1.50=h4f305b6_0
  - libhwloc=2.9.1=hd6dc26d_0
  - libiconv=1.17=hd590300_2
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - liblapack=3.9.0=16_linux64_mkl
  - liblapacke=3.9.0=16_linux64_mkl
  - libllvm15=15.0.7=hadd5161_1
  - libnpp=12.0.2.50=0
  - libnsl=2.0.1=hd590300_0
  - libnvjitlink=12.1.105=0
  - libnvjpeg=12.1.1.14=0
  - libogg=1.3.5=h4ab18f5_0
  - libopus=1.3.1=h7f98852_1
  - libpng=1.6.43=h2797004_0
  - libpq=15.3=hbcd7760_1
  - libsndfile=1.2.2=hc60ed4a_1
  - libsqlite=3.46.0=hde9e2c9_0
  - libstdcxx-ng=14.1.0=hc0a3c3a_0
  - libsystemd0=253=h8c4010b_1
  - libtiff=4.5.0=h6adf6a1_2
  - libtool=2.4.7=h27087fc_0
  - libudev1=253=h0b41bf4_1
  - libuuid=2.38.1=h0b41bf4_0
  - libvorbis=1.3.7=h9c3ff4c_0
  - libwebp-base=1.4.0=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libxcrypt=4.4.36=hd590300_1
  - libxkbcommon=1.5.0=h79f4944_1
  - libxml2=2.10.3=hca2bb57_4
  - libzlib=1.2.13=h4ab18f5_6
  - llvm-openmp=15.0.7=h0cdce71_0
  - lz4-c=1.9.4=hcb278e6_0
  - markupsafe=2.1.5=py39hd1e30aa_0
  - matplotlib=3.8.4=py39hf3d152e_2
  - matplotlib-base=3.8.4=py39h10d1fc8_2
  - mkl=2022.1.0=h84fe81f_915
  - mkl-devel=2022.1.0=ha770c72_916
  - mkl-include=2022.1.0=h84fe81f_915
  - mpc=1.3.1=hfe3b2da_0
  - mpfr=4.2.1=h9458935_1
  - mpg123=1.32.6=h59595ed_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - mysql-common=8.0.33=hf1915f5_6
  - mysql-libs=8.0.33=hca2cd23_6
  - ncurses=6.5=h59595ed_0
  - nettle=3.6=he412f7d_0
  - networkx=3.2.1=pyhd8ed1ab_0
  - nspr=4.35=h27087fc_0
  - nss=3.100=hca3bf56_0
  - numpy=2.0.0=py39ha0965c0_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.5.0=hfec8fc6_2
  - openssl=3.1.6=h4ab18f5_0
  - packaging=24.1=pyhd8ed1ab_0
  - pcre2=10.43=hcad00b1_0
  - pillow=9.4.0=py39h2320bf1_1
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.2=h59595ed_0
  - ply=3.11=pyhd8ed1ab_2
  - pthread-stubs=0.4=h36c2ea0_1001
  - pulseaudio=16.1=hcb278e6_3
  - pulseaudio-client=16.1=h5195f5e_3
  - pulseaudio-daemon=16.1=ha8d29e2_3
  - pycparser=2.22=pyhd8ed1ab_0
  - pyparsing=3.1.2=pyhd8ed1ab_0
  - pyqt=5.15.9=py39h52134e7_5
  - pyqt5-sip=12.12.2=py39h3d6467e_5
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.18=h0755675_0_cpython
  - python-dateutil=2.9.0=pyhd8ed1ab_0
  - python_abi=3.9=4_cp39
  - pytorch=2.3.1=py3.9_cuda12.1_cudnn8.9.2_0
  - pytorch-cuda=12.1=ha16c6d3_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py39hd1e30aa_1
  - qt-main=5.15.8=h5d23da1_6
  - readline=8.2=h8228510_1
  - requests=2.32.3=pyhd8ed1ab_0
  - setuptools=70.1.1=pyhd8ed1ab_0
  - shapely=2.0.4=py39h5a575da_1
  - sip=6.7.12=py39h3d6467e_0
  - six=1.16.0=pyh6c4a22f_0
  - sympy=1.12.1=pypyh2585a3b_103
  - tbb=2021.9.0=hf52228f_0
  - tk=8.6.13=noxft_h4845f30_101
  - toml=0.10.2=pyhd8ed1ab_0
  - tomli=2.0.1=pyhd8ed1ab_0
  - torchaudio=2.3.1=py39_cu121
  - torchtriton=2.3.1=py39
  - torchvision=0.18.1=py39_cu121
  - tornado=6.4.1=py39hd3abc70_0
  - tqdm=4.66.4=pyhd8ed1ab_0
  - typing_extensions=4.12.2=pyha770c72_0
  - tzdata=2024a=h0c530f3_0
  - unicodedata2=15.1.0=py39hd1e30aa_0
  - urllib3=2.2.2=pyhd8ed1ab_1
  - wheel=0.43.0=pyhd8ed1ab_1
  - xcb-util=0.4.0=h516909a_0
  - xcb-util-image=0.4.0=h166bdaf_0
  - xcb-util-keysyms=0.4.0=h516909a_0
  - xcb-util-renderutil=0.3.9=h166bdaf_0
  - xcb-util-wm=0.4.1=h516909a_0
  - xkeyboard-config=2.38=h0b41bf4_0
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.4=h0b41bf4_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxrender=0.9.10=h7f98852_1003
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - zipp=3.19.2=pyhd8ed1ab_0
  - zlib=1.2.13=h4ab18f5_6
  - zstandard=0.22.0=py39h81c9582_1
  - zstd=1.5.6=ha6fb4c9_0
  - pip:
      - lsq-ellipse==2.2.1
      - opencv-python==4.10.0.84
      - scipy==1.13.1
prefix: /home/mgutierrez/anaconda3/envs/PnLCalib
README.md
CHANGED
@@ -1,10 +1,134 @@
# PnLCalib API

> REST API for camera calibration from football pitch lines

## About this project

This API is based on the original research work of the **SoccerNet Camera Calibration Challenge**. It turns the existing calibration algorithms into an accessible, easy-to-use REST API.

### Original work
📍 **Source repository**: [SoccerNet Camera Calibration](https://github.com/SoccerNet/sn-calibration) [Marc Gutiérrez-Pérez](https://github.com/mguti97/PnLCalib)
📖 **Paper**: SoccerNet Camera Calibration Challenge
👥 **Authors**: SoccerNet team

## Features

✅ Automatic camera calibration from images of a football pitch
✅ REST API built with FastAPI
✅ Supported image formats: JPG, PNG

## Local installation

```bash
# Clone the repository
git clone https://github.com/2nzi/PnLCalib.git
cd PnLCalib

# Install the dependencies
pip install -r requirements.txt

# Start the API
python run_api.py
```

The API is served at: http://localhost:8000

## Usage

### Main endpoint: `/calibrate`

**POST** `/calibrate` - Calibrate a camera from an image and the pitch lines

**Parameters:**
- `image`: image file (multipart/form-data)
- `lines_data`: JSON of the pitch lines (string)

### Usage example

#### With JavaScript/Fetch
```javascript
const formData = new FormData();
formData.append('image', imageFile);
formData.append('lines_data', JSON.stringify({
    "Big rect. right top": [
        {"x": 1342.88, "y": 1076.99},
        {"x": 1484.74, "y": 906.37}
    ],
    "Big rect. right main": [
        {"x": 1484.74, "y": 906.37},
        {"x": 1049.62, "y": 748.02}
    ],
    "Circle central": [
        {"x": 1580.73, "y": 269.84},
        {"x": 1533.83, "y": 288.86}
    ]
    // ... other lines
}));

const response = await fetch('http://localhost:8000/calibrate', {
    method: 'POST',
    body: formData
});

const result = await response.json();
console.log('Calibration parameters:', result.camera_parameters);
```

#### With curl
```bash
curl -X POST "http://localhost:8000/calibrate" \
     -F "image=@your_image.jpg" \
     -F 'lines_data={"Big rect. right top":[{"x":1342.88,"y":1076.99}]}'
```

### Response format

```json
{
    "status": "success",
    "camera_parameters": {
        "pan_degrees": -45.2,
        "tilt_degrees": 12.8,
        "roll_degrees": 1.2,
        "position_meters": [10.5, 20.3, 5.8],
        "x_focal_length": 1200.5,
        "y_focal_length": 1201.2,
        "principal_point": [960, 540]
    },
    "input_lines": { /* validated lines */ },
    "message": "Calibration réussie"
}
```

## Documentation

Once the API is running, the interactive documentation is available at:
- **Swagger UI**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc

## Health Check

```bash
curl http://localhost:8000/health
```

## Supported pitch lines

The API accepts the following pitch line types:
- `Big rect. right/left top/main/bottom`
- `Small rect. right/left top/main`
- `Circle right/left/central`
- `Side line bottom/left`
- `Middle line`
...

Each line is defined as a list of points with `x` and `y` coordinates.

## Credits

Based on the original work of the SoccerNet team for the Camera Calibration Challenge [Marc Gutiérrez-Pérez](https://github.com/mguti97/PnLCalib).
Turned into a REST API by [2nzi](https://github.com/2nzi).

## License

See [LICENSE](LICENSE) - based on the license of the original SoccerNet project.
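To complement the JavaScript and curl examples in the README above, here is a minimal Python client sketch using the `requests` package. It assumes the API is running locally on port 8000; `field.jpg` and the line coordinates are placeholders to replace with your own image and annotations.

```python
import json

import requests

# Placeholder pitch lines: replace with real annotated lines for your image.
lines = {
    "Big rect. right top": [
        {"x": 1342.88, "y": 1076.99},
        {"x": 1484.74, "y": 906.37},
    ],
    "Big rect. right main": [
        {"x": 1484.74, "y": 906.37},
        {"x": 1049.62, "y": 748.02},
    ],
}

# POST the image as multipart data together with the lines encoded as a JSON string.
with open("field.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8000/calibrate",
        files={"image": ("field.jpg", f, "image/jpeg")},
        data={"lines_data": json.dumps(lines)},
    )

response.raise_for_status()
result = response.json()
print(result["status"])
print(result["camera_parameters"])
```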
api.py
ADDED
@@ -0,0 +1,162 @@
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, List, Any
import json
import tempfile
import os
from PIL import Image
import numpy as np

from get_camera_params import get_camera_parameters

app = FastAPI(
    title="Football Vision Calibration API",
    description="API pour la calibration de caméras à partir de lignes de terrain de football",
    version="1.0.0"
)

# CORS configuration to allow requests from the frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify the allowed domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pydantic models for data validation
class Point(BaseModel):
    x: float
    y: float

class LinePolygon(BaseModel):
    points: List[Point]

class CalibrationRequest(BaseModel):
    lines: Dict[str, List[Point]]

class CalibrationResponse(BaseModel):
    status: str
    camera_parameters: Dict[str, Any]
    input_lines: Dict[str, List[Point]]
    message: str

@app.get("/")
async def root():
    return {
        "message": "Football Vision Calibration API",
        "version": "1.0.0",
        "endpoints": {
            "/calibrate": "POST - Calibrer une caméra à partir d'une image et de lignes",
            "/health": "GET - Vérifier l'état de l'API"
        }
    }

@app.get("/health")
async def health_check():
    return {"status": "healthy", "message": "API is running"}

@app.post("/calibrate", response_model=CalibrationResponse)
async def calibrate_camera(
    image: UploadFile = File(..., description="Image du terrain de football"),
    lines_data: str = Form(..., description="JSON des lignes du terrain")
):
    """
    Calibrate a camera from an image and the pitch lines.

    Args:
        image: Image of the football pitch (formats: jpg, jpeg, png)
        lines_data: JSON containing the pitch lines in the format:
                    {"line_name": [{"x": float, "y": float}, ...], ...}

    Returns:
        Camera calibration parameters and the input lines
    """
    try:
        # Validate the image format
        if not image.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="Le fichier doit être une image")

        # Parse the lines data
        try:
            lines_dict = json.loads(lines_data)
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Format JSON invalide pour les lignes")

        # Validate the structure of the lines
        validated_lines = {}
        for line_name, points in lines_dict.items():
            if not isinstance(points, list):
                raise HTTPException(
                    status_code=400,
                    detail=f"Les points de la ligne '{line_name}' doivent être une liste"
                )

            validated_points = []
            for i, point in enumerate(points):
                if not isinstance(point, dict) or 'x' not in point or 'y' not in point:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Point {i} de la ligne '{line_name}' doit avoir les clés 'x' et 'y'"
                    )
                try:
                    validated_points.append({
                        "x": float(point['x']),
                        "y": float(point['y'])
                    })
                except (ValueError, TypeError):
                    raise HTTPException(
                        status_code=400,
                        detail=f"Coordonnées invalides pour le point {i} de la ligne '{line_name}'"
                    )

            validated_lines[line_name] = validated_points

        # Save the image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{image.filename.split('.')[-1]}") as temp_file:
            content = await image.read()
            temp_file.write(content)
            temp_image_path = temp_file.name

        try:
            # Validate the image
            pil_image = Image.open(temp_image_path)
            pil_image.verify()  # Check the integrity of the image

            # Calibrate the camera
            camera_params = get_camera_parameters(temp_image_path, validated_lines)

            # Build the response
            response = CalibrationResponse(
                status="success",
                camera_parameters=camera_params,
                input_lines=validated_lines,
                message="Calibration réussie"
            )

            return response

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Erreur lors de la calibration: {str(e)}"
            )

        finally:
            # Clean up the temporary file
            if os.path.exists(temp_image_path):
                os.unlink(temp_image_path)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur interne: {str(e)}")

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)

# Add this instead:
# Entry point for Vercel
app_instance = app
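A quick way to exercise the endpoints defined in `api.py` without deploying is FastAPI's `TestClient`. The sketch below assumes the repository's dependencies are installed and, for the `/calibrate` call, that the calibration models used by `get_camera_params` are available; otherwise only the `/health` check will succeed. `field.jpg` and the line points are placeholders.

```python
import json

from fastapi.testclient import TestClient

from api import app

client = TestClient(app)

# The health endpoint takes no payload and should always answer.
health = client.get("/health")
assert health.status_code == 200
assert health.json()["status"] == "healthy"

# Shape of a /calibrate request: a multipart image plus a JSON string of lines.
lines = {"Middle line": [{"x": 100.0, "y": 50.0}, {"x": 120.0, "y": 400.0}]}
with open("field.jpg", "rb") as f:
    response = client.post(
        "/calibrate",
        files={"image": ("field.jpg", f, "image/jpeg")},
        data={"lines_data": json.dumps(lines)},
    )
print(response.status_code, response.json())
```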
config.yaml
ADDED
@@ -0,0 +1,33 @@
# config.yaml
# weights_kp: "models/SV_FT_TSWC_kp"
# weights_line: "models/SV_FT_TSWC_lines"
# pnl_refine: true
# device: "cuda:0"
# input_path: "examples/input/"
# output_path: "examples/output/"
# name: "FootDrone.jpg"
# input_type: "image"
# display: true

# # Start with low thresholds
# thresholds:
#   - kp_threshold: 0.15
#     line_threshold: 0.15


weights_kp: "models/SV_FT_TSWC_kp"
weights_line: "models/SV_FT_TSWC_lines"
device: "cuda:0"
input_path: "examples/input/"
output_path: "examples/output/"
name: "FootDrone.mp4"
input_type: "video"
display: true
pnl_refine: true
frame_step: 5  # Process one frame out of 5
two_pass: true

# Start with low thresholds
thresholds:
  - kp_threshold: 0.15
    line_threshold: 0.15
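The configuration above is plain YAML, so it can be read with PyYAML; a minimal sketch of loading it and pulling out the values shown (assuming the file sits at `config.yaml` in the working directory):

```python
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["name"], cfg["input_type"])       # FootDrone.mp4 video
print(cfg["frame_step"], cfg["two_pass"])   # 5 True

# As written, `thresholds` parses as a one-element list carrying both thresholds.
thr = cfg["thresholds"][0]
print(thr["kp_threshold"], thr["line_threshold"])  # 0.15 0.15
```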
config/field_config.py
ADDED
@@ -0,0 +1,164 @@
"""
Configuration file for football field elements (lines and keypoints)
All measurements are in meters
Field dimensions: 105m x 68m
Origin (0,0) is at top-left corner
"""

# Lines configuration
LINES = {
    1: {"name": "Big rect. left bottom", "description": "Surface de réparation gauche - ligne basse"},
    2: {"name": "Big rect. left main", "description": "Surface de réparation gauche - ligne parallèle"},
    3: {"name": "Big rect. left top", "description": "Surface de réparation gauche - ligne haute"},
    4: {"name": "Big rect. right bottom", "description": "Surface de réparation droite - ligne basse"},
    5: {"name": "Big rect. right main", "description": "Surface de réparation droite - ligne parallèle"},
    6: {"name": "Big rect. right top", "description": "Surface de réparation droite - ligne haute"},
    7: {"name": "Goal left crossbar", "description": "Barre transversale but gauche"},
    8: {"name": "Goal left post left", "description": "Poteau gauche but gauche"},
    9: {"name": "Goal left post right", "description": "Poteau droit but gauche"},
    10: {"name": "Goal right crossbar", "description": "Barre transversale but droit"},
    11: {"name": "Goal right post left", "description": "Poteau gauche but droit"},
    12: {"name": "Goal right post right", "description": "Poteau droit but droit"},
    13: {"name": "Middle line", "description": "Ligne médiane"},
    14: {"name": "Side line bottom", "description": "Ligne de but"},
    15: {"name": "Side line left", "description": "Ligne de touche gauche"},
    16: {"name": "Side line right", "description": "Ligne de touche droite"},
    17: {"name": "Side line top", "description": "Ligne de but opposée"},
    18: {"name": "Small rect. left bottom", "description": "Petite surface gauche - ligne basse"},
    19: {"name": "Small rect. left main", "description": "Petite surface gauche - ligne parallèle"},
    20: {"name": "Small rect. left top", "description": "Petite surface gauche - ligne haute"},
    21: {"name": "Small rect. right bottom", "description": "Petite surface droite - ligne basse"},
    22: {"name": "Small rect. right main", "description": "Petite surface droite - ligne parallèle"},
    23: {"name": "Small rect. right top", "description": "Petite surface droite - ligne haute"}
}

# Keypoints configuration
KEYPOINTS = {
    # Main points (1-30)
    1: {"name": "Top Left Corner", "coords": [0.0, 0.0], "description": "Coin supérieur gauche"},
    2: {"name": "Top Middle", "coords": [52.5, 0.0], "description": "Milieu ligne du haut"},
    3: {"name": "Top Right Corner", "coords": [105.0, 0.0], "description": "Coin supérieur droit"},
    4: {"name": "Left Big Box Top", "coords": [0.0, 16.5], "description": "Surface gauche haut"},
    5: {"name": "Left Big Box Main", "coords": [16.5, 16.5], "description": "Surface gauche principale"},
    6: {"name": "Right Big Box Main", "coords": [88.5, 16.5], "description": "Surface droite principale"},
    7: {"name": "Right Big Box Top", "coords": [105.0, 16.5], "description": "Surface droite haut"},
    8: {"name": "Left Small Box Top", "coords": [0.0, 5.5], "description": "Petite surface gauche haut"},
    9: {"name": "Left Small Box Main", "coords": [5.5, 5.5], "description": "Petite surface gauche principale"},
    10: {"name": "Right Small Box Main", "coords": [99.5, 5.5], "description": "Petite surface droite principale"},
    11: {"name": "Right Small Box Top", "coords": [105.0, 5.5], "description": "Petite surface droite haut"},
    12: {"name": "Left Goal Crossbar Right", "coords": [3.66, 0.0], "description": "Barre transversale gauche - droite"},
    13: {"name": "Left Goal Post Right", "coords": [3.66, 2.44], "description": "Poteau droit but gauche"},
    14: {"name": "Right Goal Post Left", "coords": [101.34, 2.44], "description": "Poteau gauche but droit"},
    15: {"name": "Right Goal Crossbar Left", "coords": [101.34, 0.0], "description": "Barre transversale droite - gauche"},
    16: {"name": "Left Goal Crossbar Left", "coords": [0.0, 0.0], "description": "Barre transversale gauche - gauche"},
    17: {"name": "Left Goal Post Left", "coords": [0.0, 2.44], "description": "Poteau gauche but gauche"},
    18: {"name": "Right Goal Post Right", "coords": [105.0, 2.44], "description": "Poteau droit but droit"},
    19: {"name": "Right Goal Crossbar Right", "coords": [105.0, 0.0], "description": "Barre transversale droite - droite"},
    20: {"name": "Left Small Box Bottom", "coords": [0.0, 0.0], "description": "Petite surface gauche bas"},
    21: {"name": "Left Small Box Bottom Main", "coords": [5.5, 0.0], "description": "Petite surface gauche bas principale"},
    22: {"name": "Right Small Box Bottom Main", "coords": [99.5, 0.0], "description": "Petite surface droite bas principale"},
    23: {"name": "Right Small Box Bottom", "coords": [105.0, 0.0], "description": "Petite surface droite bas"},
    24: {"name": "Left Big Box Bottom", "coords": [0.0, 0.0], "description": "Surface gauche bas"},
    25: {"name": "Bottom Middle", "coords": [52.5, 0.0], "description": "Milieu ligne du bas"},
    26: {"name": "Right Big Box Bottom", "coords": [105.0, 0.0], "description": "Surface droite bas"},
    27: {"name": "Left Big Box Bottom Main", "coords": [16.5, 0.0], "description": "Surface gauche bas principale"},
    28: {"name": "Bottom Middle Top", "coords": [52.5, 68.0], "description": "Milieu ligne opposée"},
    29: {"name": "Right Big Box Bottom Main", "coords": [88.5, 0.0], "description": "Surface droite bas principale"},
    30: {"name": "Left Penalty Spot", "coords": [11.0, 34.0], "description": "Point de penalty gauche"},

    # Auxiliary points (31-57)
    31: {"name": "Left Box Aux 1", "coords": [16.5, 20.0], "description": "Point auxiliaire surface gauche 1"},
    32: {"name": "Center Circle Left", "coords": [43.35, 34.0], "description": "Rond central gauche"},
    33: {"name": "Center Circle Right", "coords": [61.65, 34.0], "description": "Rond central droit"},
    34: {"name": "Right Box Aux 1", "coords": [88.5, 20.0], "description": "Point auxiliaire surface droite 1"},
    35: {"name": "Left Box Aux 2", "coords": [16.5, 48.0], "description": "Point auxiliaire surface gauche 2"},
    36: {"name": "Right Box Aux 2", "coords": [88.5, 48.0], "description": "Point auxiliaire surface droite 2"},
    37: {"name": "Center Circle Top Left", "coords": [43.35, 24.85], "description": "Rond central haut gauche"},
    38: {"name": "Center Circle Top Right", "coords": [61.65, 24.85], "description": "Rond central haut droit"},
    39: {"name": "Center Circle Bottom Left", "coords": [43.35, 43.15], "description": "Rond central bas gauche"},
    40: {"name": "Center Circle Bottom Right", "coords": [61.65, 43.15], "description": "Rond central bas droit"},
    41: {"name": "Center Circle Left Inner", "coords": [46.03, 34.0], "description": "Rond central intérieur gauche"},
    42: {"name": "Center Circle Right Inner", "coords": [58.97, 34.0], "description": "Rond central intérieur droit"},
    43: {"name": "Center Circle Top Inner", "coords": [52.5, 27.53], "description": "Rond central intérieur haut"},
    44: {"name": "Center Circle Bottom Inner", "coords": [52.5, 40.47], "description": "Rond central intérieur bas"},
    45: {"name": "Left Penalty Arc Left", "coords": [19.99, 32.29], "description": "Arc penalty gauche - point gauche"},
    46: {"name": "Left Penalty Arc Right", "coords": [19.99, 35.71], "description": "Arc penalty gauche - point droit"},
    47: {"name": "Right Penalty Arc Left", "coords": [85.01, 32.29], "description": "Arc penalty droit - point gauche"},
    48: {"name": "Right Penalty Arc Right", "coords": [85.01, 35.71], "description": "Arc penalty droit - point droit"},
    49: {"name": "Left Penalty Arc Top", "coords": [16.5, 34.0], "description": "Arc penalty gauche - point haut"},
    50: {"name": "Right Penalty Arc Top", "coords": [88.5, 34.0], "description": "Arc penalty droit - point haut"},
    51: {"name": "Left Penalty Area Center", "coords": [16.5, 34.0], "description": "Centre surface de réparation gauche"},
    52: {"name": "Right Penalty Area Center", "coords": [88.5, 34.0], "description": "Centre surface de réparation droite"},
91 |
+
52: {"name": "Right Penalty Area Center", "coords": [88.5, 34.0], "description": "Centre surface de réparation droite"},
|
92 |
+
53: {"name": "Right Penalty Spot", "coords": [94.0, 34.0], "description": "Point de penalty droit"},
|
93 |
+
54: {"name": "Center Spot", "coords": [52.5, 34.0], "description": "Point central"},
|
94 |
+
55: {"name": "Left Box Center", "coords": [16.5, 34.0], "description": "Centre surface gauche"},
|
95 |
+
56: {"name": "Right Box Center", "coords": [88.5, 34.0], "description": "Centre surface droite"},
|
96 |
+
57: {"name": "Center Circle Center", "coords": [52.5, 34.0], "description": "Centre rond central"}
|
97 |
+
}
|
98 |
+
|
99 |
+
# Field dimensions
|
100 |
+
FIELD_DIMENSIONS = {
|
101 |
+
"length": 105.0, # meters
|
102 |
+
"width": 68.0, # meters
|
103 |
+
"center_circle_radius": 9.15,
|
104 |
+
"penalty_area_length": 16.5,
|
105 |
+
"penalty_area_width": 40.32,
|
106 |
+
"goal_area_length": 5.5,
|
107 |
+
"goal_area_width": 18.32,
|
108 |
+
"penalty_spot_dist": 11.0,
|
109 |
+
"goal_height": 2.44,
|
110 |
+
"goal_width": 7.32
|
111 |
+
}
|
112 |
+
|
113 |
+
# Line intersections that form keypoints
|
114 |
+
KEYPOINT_LINE_PAIRS = [
|
115 |
+
["Side line top", "Side line left"],
|
116 |
+
["Side line top", "Middle line"],
|
117 |
+
["Side line right", "Side line top"],
|
118 |
+
["Side line left", "Big rect. left top"],
|
119 |
+
["Big rect. left top", "Big rect. left main"],
|
120 |
+
["Big rect. right top", "Big rect. right main"],
|
121 |
+
["Side line right", "Big rect. right top"],
|
122 |
+
["Side line left", "Small rect. left top"],
|
123 |
+
["Small rect. left top", "Small rect. left main"],
|
124 |
+
["Small rect. right top", "Small rect. right main"],
|
125 |
+
["Side line right", "Small rect. right top"],
|
126 |
+
["Goal left crossbar", "Goal left post right"],
|
127 |
+
["Side line left", "Goal left post right"],
|
128 |
+
["Side line right", "Goal right post left"],
|
129 |
+
["Goal right crossbar", "Goal right post left"],
|
130 |
+
["Goal left crossbar", "Goal left post left"],
|
131 |
+
["Side line left", "Goal left post left"],
|
132 |
+
["Side line right", "Goal right post right"],
|
133 |
+
["Goal right crossbar", "Goal right post right"],
|
134 |
+
["Side line left", "Small rect. left bottom"],
|
135 |
+
["Small rect. left bottom", "Small rect. left main"],
|
136 |
+
["Small rect. right bottom", "Small rect. right main"],
|
137 |
+
["Side line right", "Small rect. right bottom"],
|
138 |
+
["Side line left", "Big rect. left bottom"],
|
139 |
+
["Big rect. left bottom", "Big rect. left main"],
|
140 |
+
["Big rect. right bottom", "Big rect. right main"],
|
141 |
+
["Side line right", "Big rect. right bottom"],
|
142 |
+
["Side line left", "Side line bottom"],
|
143 |
+
["Side line right", "Side line bottom"]
|
144 |
+
]
|
145 |
+
|
146 |
+
# Auxiliary keypoint pairs
|
147 |
+
KEYPOINT_AUX_PAIRS = [
|
148 |
+
["Small rect. left main", "Side line top"],
|
149 |
+
["Big rect. left main", "Side line top"],
|
150 |
+
["Big rect. right main", "Side line top"],
|
151 |
+
["Small rect. right main", "Side line top"],
|
152 |
+
["Small rect. left main", "Big rect. left top"],
|
153 |
+
["Big rect. right top", "Small rect. right main"],
|
154 |
+
["Small rect. left top", "Big rect. left main"],
|
155 |
+
["Small rect. right top", "Big rect. right main"],
|
156 |
+
["Small rect. left bottom", "Big rect. left main"],
|
157 |
+
["Small rect. right bottom", "Big rect. right main"],
|
158 |
+
["Small rect. left main", "Big rect. left bottom"],
|
159 |
+
["Small rect. right main", "Big rect. right bottom"],
|
160 |
+
["Small rect. left main", "Side line bottom"],
|
161 |
+
["Big rect. left main", "Side line bottom"],
|
162 |
+
["Big rect. right main", "Side line bottom"],
|
163 |
+
["Small rect. right main", "Side line bottom"]
|
164 |
+
]
|
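Editor's note (not one of the uploaded files): the keypoint coordinates above are in metres on a 105 x 68 pitch, so they can be normalized against FIELD_DIMENSIONS before being compared with model outputs. A minimal sketch, assuming the module is importable as config.field_config from the repository root:

# Hypothetical usage sketch of the field configuration above.
from config.field_config import KEYPOINTS, FIELD_DIMENSIONS

def keypoint_to_normalized(kp_id):
    # Convert a keypoint's metric pitch coordinates to [0, 1] pitch coordinates.
    x_m, y_m = KEYPOINTS[kp_id]["coords"]
    return (x_m / FIELD_DIMENSIONS["length"], y_m / FIELD_DIMENSIONS["width"])

print(keypoint_to_normalized(54))  # centre spot -> (0.5, 0.5)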
config/hrnetv2_w48.yaml
ADDED
@@ -0,0 +1,35 @@
MODEL:
  IMAGE_SIZE: [960, 540]
  NUM_JOINTS: 58
  PRETRAIN: ''
  EXTRA:
    FINAL_CONV_KERNEL: 1
    STAGE1:
      NUM_MODULES: 1
      NUM_BRANCHES: 1
      BLOCK: BOTTLENECK
      NUM_BLOCKS: [4]
      NUM_CHANNELS: [64]
      FUSE_METHOD: SUM
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4]
      NUM_CHANNELS: [48, 96]
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4, 4]
      NUM_CHANNELS: [48, 96, 192]
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4, 4, 4]
      NUM_CHANNELS: [48, 96, 192, 384]
      FUSE_METHOD: SUM
config/hrnetv2_w48_l.yaml
ADDED
@@ -0,0 +1,35 @@
MODEL:
  IMAGE_SIZE: [960, 540]
  NUM_JOINTS: 24
  PRETRAIN: ''
  EXTRA:
    FINAL_CONV_KERNEL: 1
    STAGE1:
      NUM_MODULES: 1
      NUM_BRANCHES: 1
      BLOCK: BOTTLENECK
      NUM_BLOCKS: [4]
      NUM_CHANNELS: [64]
      FUSE_METHOD: SUM
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4]
      NUM_CHANNELS: [48, 96]
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4, 4]
      NUM_CHANNELS: [48, 96, 192]
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4, 4, 4]
      NUM_CHANNELS: [48, 96, 192, 384]
      FUSE_METHOD: SUM
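Editor's note: the two configs differ only in NUM_JOINTS. This appears to correspond to the 57 pitch keypoints plus one extra heatmap channel for the keypoint model, and the 23 pitch lines plus one extra channel for the line model (inference.py below drops the last channel before decoding). A quick sketch to confirm the values:

# Sketch: load both configs and compare NUM_JOINTS.
import yaml

cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml"))
cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml"))
print(cfg["MODEL"]["NUM_JOINTS"], cfg_l["MODEL"]["NUM_JOINTS"])  # 58, 24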
data/__init__.py
ADDED
@@ -0,0 +1 @@
# Empty file to mark the directory as a Python package
data/line_data.py
ADDED
@@ -0,0 +1,115 @@
# Contains cam1_line_dict and cam3_line_dict

cam1_line_dict = {
    "Side line top": [
        {"x": 0, "y": 205.35510086780727},
        {"x": 226.69196710942427, "y": 218.52089231566922},
        {"x": 609.2346616065778, "y": 235.73769651671947},
        {"x": 1072.7387729285256, "y": 248.90348796458142},
        {"x": 1642.5047438330162, "y": 257.005513470958},
        {"x": 1918.7855787476271, "y": 257.005513470958}
    ],
    "Big rect. right top": [
        {"x": 1918.6823296329553, "y": 327.02322511812105},
        {"x": 1761.8195845349164, "y": 327.02322511812105}
    ],
    "Big rect. right main": [
        {"x": 1761.8195845349164, "y": 327.02322511812105},
        {"x": 1774.5760335197388, "y": 409.7284234765342},
        {"x": 1790.8194197631253, "y": 526.9377258021154},
        {"x": 1809.8670388107444, "y": 671.5805829449724}
    ],
    "Big rect. right bottom": [
        {"x": 1809.8670388107444, "y": 671.5805829449724},
        {"x": 1917.4300640208282, "y": 677.1868952373312}
    ],
    "Circle right": [
        {"x": 1774.5760335197388, "y": 409.7284234765342},
        {"x": 1709.2165563955555, "y": 410.6005164997901},
        {"x": 1629.0422644565579, "y": 423.6819118486273},
        {"x": 1576.7546827572116, "y": 442.8679583602552},
        {"x": 1568.0400858073206, "y": 459.43772580211566},
        {"x": 1592.4409572670156, "y": 482.11214440676673},
        {"x": 1698.7590400556867, "y": 511.7633071974644},
        {"x": 1790.8194197631253, "y": 526.9377258021154}
    ],
    "Middle line": [
        {"x": 0, "y": 308.65587267083185},
        {"x": 226.69196710942427, "y": 218.52089231566922}
    ],
    "Circle central": [
        {"x": 0, "y": 340.8997444495873},
        {"x": 122.94647588765228, "y": 351.0820197481417},
        {"x": 234.87016428192882, "y": 366.3554326959732},
        {"x": 305.2464228934815, "y": 390.1140750592667},
        {"x": 309.48595654477987, "y": 413.8727174225603},
        {"x": 228.08691043985144, "y": 441.0254515520386},
        {"x": 71.22416534181235, "y": 463.08704803223975},
        {"x": 4.239533651298354, "y": 468.1781856815169}
    ],
    "Side line bottom": [
        {"x": 1.3071895424836595, "y": 814.522819727929},
        {"x": 636.6013071895421, "y": 827.6042150767662},
        {"x": 1286.2745098039209, "y": 852.458866239557},
        {"x": 1918.954248366012, "y": 889.086773216301}
    ]
}

cam3_line_dict = {
    "Big rect. right top": [
        {"x": 1342.8861505076343, "y": 1076.997434976179},
        {"x": 1484.7446330310781, "y": 906.3705391217808}
    ],
    "Big rect. right main": [
        {"x": 1484.7446330310781, "y": 906.3705391217808},
        {"x": 1049.6210183678218, "y": 748.0287797688992},
        {"x": 828.6491513601493, "y": 668.8579000924583},
        {"x": 349.8767728435256, "y": 500.9610345717304},
        {"x": 32.736572890025556, "y": 397.21988189225624}
    ],
    "Big rect. right bottom": [
        {"x": 32.736572890025556, "y": 397.21988189225624},
        {"x": 0.3753980224568448, "y": 407.0286292126068}
    ],
    "Small rect. right top": [
        {"x": 312.24913494809687, "y": 1075.6461846681693},
        {"x": 426.66666666666663, "y": 999.9279904137233}
    ],
    "Small rect. right main": [
        {"x": 426.66666666666663, "y": 999.9279904137233},
        {"x": 0, "y": 769.079837198949}
    ],
    "Circle right": [
        {"x": 828.6491513601493, "y": 668.8579000924583},
        {"x": 821.7759602949911, "y": 612.2830792373484},
        {"x": 782.8739995106773, "y": 564.5621490047902},
        {"x": 722.6387053930304, "y": 529.3993583071158},
        {"x": 623.5014504910696, "y": 503.02726528386006},
        {"x": 494.24654853028534, "y": 492.980753655953},
        {"x": 349.8767728435256, "y": 500.9610345717304}
    ],
    "Side line bottom": [
        {"x": 2.0193824656299317, "y": 266.2605192109321},
        {"x": 399.0443993689428, "y": 186.14824976426013},
        {"x": 645.5533017804819, "y": 132.93313314748357},
        {"x": 1001.1088573360372, "y": 53.39824942655338},
        {"x": 1208.1676808654488, "y": 7.351737798646435}
    ],
    "Middle line": [
        {"x": 645.5533017804819, "y": 132.93313314748357},
        {"x": 1106.0585089650835, "y": 200.22939899146556},
        {"x": 1580.7388158704541, "y": 269.8451725000601},
        {"x": 1917.6527118636336, "y": 318.9857185061268}
    ],
    "Circle central": [
        {"x": 1580.7388158704541, "y": 269.8451725000601},
        {"x": 1580.7388158704541, "y": 269.8451725000601},
        {"x": 1533.8366024891266, "y": 288.8643838246303},
        {"x": 1441.810458698277, "y": 302.46903498742097},
        {"x": 1316.3202626198458, "y": 304.5620582432349},
        {"x": 1219.0653606590615, "y": 292.0039187083512},
        {"x": 1135.4052299401073, "y": 274.2132210339326},
        {"x": 1069.522876998931, "y": 237.5853140571884},
        {"x": 1106.0585089650835, "y": 200.22939899146556}
    ]
}
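Editor's note: a quick way to sanity-check these pixel annotations is to draw them back onto the corresponding frame. A minimal sketch, assuming the module import path data.line_data and a hypothetical image path for cam1:

# Sketch: overlay the cam1 line annotations on a frame (path is an assumption).
import cv2
from data.line_data import cam1_line_dict

img = cv2.imread("examples/input/cam1.jpg")  # hypothetical path
if img is not None:
    for name, points in cam1_line_dict.items():
        for a, b in zip(points[:-1], points[1:]):
            cv2.line(img, (int(a["x"]), int(a["y"])), (int(b["x"]), int(b["y"])), (0, 255, 0), 2)
    cv2.imwrite("cam1_annotated.jpg", img)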
get_camera_params.py
ADDED
@@ -0,0 +1,201 @@
from utils.utils_keypoints import KeypointsDB
from utils.utils_lines import LineKeypointsDB
from utils.utils_calib import FramebyFrameCalib
from utils.utils_heatmap import complete_keypoints
from PIL import Image
import torch
import numpy as np

# Line data for cam3
cam3_line_dict = {
    "Big rect. right top": [
        {"x": 1342.8861505076343, "y": 1076.997434976179},
        {"x": 1484.7446330310781, "y": 906.3705391217808}
    ],
    "Big rect. right main": [
        {"x": 1484.7446330310781, "y": 906.3705391217808},
        {"x": 1049.6210183678218, "y": 748.0287797688992},
        {"x": 828.6491513601493, "y": 668.8579000924583},
        {"x": 349.8767728435256, "y": 500.9610345717304},
        {"x": 32.736572890025556, "y": 397.21988189225624}
    ],
    "Big rect. right bottom": [
        {"x": 32.736572890025556, "y": 397.21988189225624},
        {"x": 0.3753980224568448, "y": 407.0286292126068}
    ],
    "Small rect. right top": [
        {"x": 312.24913494809687, "y": 1075.6461846681693},
        {"x": 426.66666666666663, "y": 999.9279904137233}
    ],
    "Small rect. right main": [
        {"x": 426.66666666666663, "y": 999.9279904137233},
        {"x": 0, "y": 769.079837198949}
    ],
    "Circle right": [
        {"x": 828.6491513601493, "y": 668.8579000924583},
        {"x": 821.7759602949911, "y": 612.2830792373484},
        {"x": 782.8739995106773, "y": 564.5621490047902},
        {"x": 722.6387053930304, "y": 529.3993583071158},
        {"x": 623.5014504910696, "y": 503.02726528386006},
        {"x": 494.24654853028534, "y": 492.980753655953},
        {"x": 349.8767728435256, "y": 500.9610345717304}
    ],
    "Side line bottom": [
        {"x": 2.0193824656299317, "y": 266.2605192109321},
        {"x": 399.0443993689428, "y": 186.14824976426013},
        {"x": 645.5533017804819, "y": 132.93313314748357},
        {"x": 1001.1088573360372, "y": 53.39824942655338},
        {"x": 1208.1676808654488, "y": 7.351737798646435}
    ],
    "Middle line": [
        {"x": 645.5533017804819, "y": 132.93313314748357},
        {"x": 1106.0585089650835, "y": 200.22939899146556},
        {"x": 1580.7388158704541, "y": 269.8451725000601},
        {"x": 1917.6527118636336, "y": 318.9857185061268}
    ],
    "Circle central": [
        {"x": 1580.7388158704541, "y": 269.8451725000601},
        {"x": 1580.7388158704541, "y": 269.8451725000601},
        {"x": 1533.8366024891266, "y": 288.8643838246303},
        {"x": 1441.810458698277, "y": 302.46903498742097},
        {"x": 1316.3202626198458, "y": 304.5620582432349},
        {"x": 1219.0653606590615, "y": 292.0039187083512},
        {"x": 1135.4052299401073, "y": 274.2132210339326},
        {"x": 1069.522876998931, "y": 237.5853140571884},
        {"x": 1106.0585089650835, "y": 200.22939899146556},
        {"x": 1139.5882364760548, "y": 189.4457791734675},
        {"x": 1224.2941188289963, "y": 177.9341512664908},
        {"x": 1314.2287593518718, "y": 174.79461638276985},
        {"x": 1392.6601319008914, "y": 180.02717452230473},
        {"x": 1465.8627462799764, "y": 190.49229080137454},
        {"x": 1529.6535959531789, "y": 204.09694196416518},
        {"x": 1581.9411776525253, "y": 230.2597326618396},
        {"x": 1580.7388158704541, "y": 269.8451725000601}
    ],
    "Side line left": [
        {"x": 1208.1676808654488, "y": 7.351737798646435},
        {"x": 1401.9652021886754, "y": 20.565213248502545},
        {"x": 1582.3573590514204, "y": 30.37625976013045},
        {"x": 1679.416182580832, "y": 34.300678364781604},
        {"x": 1824.5142217965183, "y": 41.23091697692868},
        {"x": 1918.6318688553417, "y": 42.21202162809147}
    ],
    "Big rect. left bottom": [
        {"x": 1401.9652021886754, "y": 20.565213248502545},
        {"x": 1283.3377512082834, "y": 53.98527744204496}
    ],
    "Big rect. left main": [
        {"x": 1283.3377512082834, "y": 53.98527744204496},
        {"x": 1510.7887316004399, "y": 73.60737046530076},
        {"x": 1808.8279472867146, "y": 94.21056813971936},
        {"x": 1918.6318688553417, "y": 100.0971960466961}
    ],
    "Circle left": [
        {"x": 1510.7887316004399, "y": 73.60737046530076},
        {"x": 1548.0436335612244, "y": 86.36173093041702},
        {"x": 1620.5926531690673, "y": 95.19167279088215},
        {"x": 1681.3769668945574, "y": 97.15388209320773},
        {"x": 1746.0828492474989, "y": 100.0971960466961},
        {"x": 1808.8279472867146, "y": 94.21056813971936}
    ],
    "Small rect. left bottom": [
        {"x": 1550.9848100318127, "y": 42.21202162809147},
        {"x": 1582.3573590514204, "y": 30.37625976013045}
    ],
    "Small rect. left main": [
        {"x": 1550.9848100318127, "y": 42.21202162809147},
        {"x": 1918.418689198772, "y": 60.49417894940041}
    ]
}

def transform_data(line_dict, width, height):
    """
    Transform input line dictionary to normalized coordinates.
    """
    transformed = {}
    for line_name, points in line_dict.items():
        transformed[line_name] = []
        for point in points:
            transformed[line_name].append({
                "x": point["x"] / width,
                "y": point["y"] / height
            })
    return transformed

def get_camera_parameters(image_path, line_dict):
    """
    Extract camera parameters from image and line data.

    Args:
        image_path (str): Path to the image file
        line_dict (dict): Dictionary containing line coordinates

    Returns:
        dict: Camera parameters
    """
    # Load image
    image = Image.open(image_path)
    image_tensor = torch.FloatTensor(np.array(image)).permute(2, 0, 1)

    # Get image dimensions
    img_width, img_height = image.size

    # Transform data using actual image dimensions
    trans_data = transform_data(line_dict, img_width, img_height)

    # Initialize databases
    kp_db = KeypointsDB(trans_data, image_tensor)
    ln_db = LineKeypointsDB(trans_data, image_tensor)

    # Get keypoints and lines
    kp_db.get_full_keypoints()
    ln_db.get_lines()

    kp_dict = kp_db.keypoints_final
    ln_dict = ln_db.lines

    # Complete keypoints
    kp_dict, ln_dict = complete_keypoints(kp_dict, ln_dict, img_width, img_height)

    # Initialize calibration
    cam = FramebyFrameCalib(img_width, img_height)
    cam.update(kp_dict, ln_dict)
    cam_params = cam.heuristic_voting(refine_lines=True)

    return cam_params

def main():
    # Path to your image
    image_path = "examples/input/cam3.jpg"

    # Get the camera parameters
    camera_params = get_camera_parameters(image_path, cam3_line_dict)

    # Print the parameters
    print("=== PARAMÈTRES DE LA CAMÉRA ===")
    print(f"Position (mètres): {camera_params['cam_params']['position_meters']}")
    print(f"Distance focale X: {camera_params['cam_params']['x_focal_length']:.2f}")
    print(f"Distance focale Y: {camera_params['cam_params']['y_focal_length']:.2f}")
    print(f"Point principal: {camera_params['cam_params']['principal_point']}")
    print(f"Matrice de rotation:")
    rotation_matrix = np.array(camera_params['cam_params']['rotation_matrix'])
    print(rotation_matrix)

    # Compute Euler angles from the rotation matrix
    euler_angles = np.array([
        np.arctan2(rotation_matrix[2,1], rotation_matrix[2,2]),  # roll
        np.arctan2(-rotation_matrix[2,0], np.sqrt(rotation_matrix[2,1]**2 + rotation_matrix[2,2]**2)),  # pitch
        np.arctan2(rotation_matrix[1,0], rotation_matrix[0,0])  # yaw
    ]) * 180 / np.pi

    print(f"Angles d'Euler (degrés):")
    print(f"  Roll: {euler_angles[0]:.1f}°")
    print(f"  Pitch: {euler_angles[1]:.1f}°")
    print(f"  Yaw: {euler_angles[2]:.1f}°")

    # print(camera_params)

    return camera_params

if __name__ == "__main__":
    main()
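Editor's note: the same entry point can be reused for the cam1 annotations shipped in data/line_data.py. A minimal sketch, assuming a hypothetical cam1 image path; the result layout follows main() above:

# Sketch: run the calibration on cam1 (image path is an assumption).
from data.line_data import cam1_line_dict
from get_camera_params import get_camera_parameters

params = get_camera_parameters("examples/input/cam1.jpg", cam1_line_dict)
print(params["cam_params"]["position_meters"])
print(params["cam_params"]["x_focal_length"], params["cam_params"]["y_focal_length"])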
inference.py
ADDED
@@ -0,0 +1,286 @@
import cv2
import yaml
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision.transforms.functional as f

from tqdm import tqdm
from PIL import Image
from matplotlib.patches import Polygon

from model.cls_hrnet import get_cls_net
from model.cls_hrnet_l import get_cls_net as get_cls_net_l

from utils.utils_calib import FramebyFrameCalib, pan_tilt_roll_to_orientation
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
    complete_keypoints, coords_to_dict


lines_coords = [[[0., 54.16, 0.], [16.5, 54.16, 0.]],
                [[16.5, 13.84, 0.], [16.5, 54.16, 0.]],
                [[16.5, 13.84, 0.], [0., 13.84, 0.]],
                [[88.5, 54.16, 0.], [105., 54.16, 0.]],
                [[88.5, 13.84, 0.], [88.5, 54.16, 0.]],
                [[88.5, 13.84, 0.], [105., 13.84, 0.]],
                [[0., 37.66, -2.44], [0., 30.34, -2.44]],
                [[0., 37.66, 0.], [0., 37.66, -2.44]],
                [[0., 30.34, 0.], [0., 30.34, -2.44]],
                [[105., 37.66, -2.44], [105., 30.34, -2.44]],
                [[105., 30.34, 0.], [105., 30.34, -2.44]],
                [[105., 37.66, 0.], [105., 37.66, -2.44]],
                [[52.5, 0., 0.], [52.5, 68, 0.]],
                [[0., 68., 0.], [105., 68., 0.]],
                [[0., 0., 0.], [0., 68., 0.]],
                [[105., 0., 0.], [105., 68., 0.]],
                [[0., 0., 0.], [105., 0., 0.]],
                [[0., 43.16, 0.], [5.5, 43.16, 0.]],
                [[5.5, 43.16, 0.], [5.5, 24.84, 0.]],
                [[5.5, 24.84, 0.], [0., 24.84, 0.]],
                [[99.5, 43.16, 0.], [105., 43.16, 0.]],
                [[99.5, 43.16, 0.], [99.5, 24.84, 0.]],
                [[99.5, 24.84, 0.], [105., 24.84, 0.]]]


def projection_from_cam_params(final_params_dict):
    cam_params = final_params_dict["cam_params"]
    x_focal_length = cam_params['x_focal_length']
    y_focal_length = cam_params['y_focal_length']
    principal_point = np.array(cam_params['principal_point'])
    position_meters = np.array(cam_params['position_meters'])
    rotation = np.array(cam_params['rotation_matrix'])

    It = np.eye(4)[:-1]
    It[:, -1] = -position_meters
    Q = np.array([[x_focal_length, 0, principal_point[0]],
                  [0, y_focal_length, principal_point[1]],
                  [0, 0, 1]])
    P = Q @ (rotation @ It)

    return P


def inference(cam, frame, model, model_l, kp_threshold, line_threshold, pnl_refine):
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = Image.fromarray(frame)

    frame = f.to_tensor(frame).float().unsqueeze(0)
    _, _, h_original, w_original = frame.size()
    frame = frame if frame.size()[-1] == 960 else transform2(frame)
    frame = frame.to(device)
    b, c, h, w = frame.size()

    with torch.no_grad():
        heatmaps = model(frame)
        heatmaps_l = model_l(frame)

    kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
    # print("\n"*4,"kp_coords: ", kp_coords)
    line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
    print("\n"*4,"line_coords_from_heatmap: ", line_coords)
    kp_dict = coords_to_dict(kp_coords, threshold=kp_threshold)
    lines_dict = coords_to_dict(line_coords, threshold=line_threshold)

    # print("\n=== AVANT complete_keypoints ===")
    # print("--- KEYPOINTS ---")
    # print(f"Nombre de keypoints: {len(kp_dict[0])}")
    # print("Structure des keypoints:")
    for kp_key, kp_value in kp_dict[0].items():
        print(f"{kp_key}: {kp_value}")

    # print("\n--- LIGNES ---")
    # print(f"Nombre de lignes: {len(lines_dict[0])}")
    # print("Structure des lignes:")
    for line_key, line_value in lines_dict[0].items():
        print(f"{line_key}: {line_value}")

    kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)

    # print("\n=== APRÈS complete_keypoints ===")
    # print("--- KEYPOINTS ---")
    # print(f"Nombre de keypoints: {len(kp_dict)}")
    # print("Structure des keypoints:")
    for kp_key, kp_value in kp_dict.items():
        print(f"{kp_key}: {kp_value}")

    # print("\n--- LIGNES ---")
    # print(f"Nombre de lignes: {len(lines_dict)}")
    # print("Structure des lignes:")
    for line_key, line_value in lines_dict.items():
        print(f"{line_key}: {line_value}")

    cam.update(kp_dict, lines_dict)
    final_params_dict = cam.heuristic_voting(refine_lines=pnl_refine)

    return final_params_dict


def project(frame, P):

    for line in lines_coords:
        w1 = line[0]
        w2 = line[1]
        i1 = P @ np.array([w1[0]-105/2, w1[1]-68/2, w1[2], 1])
        i2 = P @ np.array([w2[0]-105/2, w2[1]-68/2, w2[2], 1])
        i1 /= i1[-1]
        i2 /= i2[-1]
        frame = cv2.line(frame, (int(i1[0]), int(i1[1])), (int(i2[0]), int(i2[1])), (255, 0, 0), 3)

    r = 9.15
    pts1, pts2, pts3 = [], [], []
    base_pos = np.array([11-105/2, 68/2-68/2, 0., 0.])
    for ang in np.linspace(37, 143, 50):
        ang = np.deg2rad(ang)
        pos = base_pos + np.array([r*np.sin(ang), r*np.cos(ang), 0., 1.])
        ipos = P @ pos
        ipos /= ipos[-1]
        pts1.append([ipos[0], ipos[1]])

    base_pos = np.array([94-105/2, 68/2-68/2, 0., 0.])
    for ang in np.linspace(217, 323, 200):
        ang = np.deg2rad(ang)
        pos = base_pos + np.array([r*np.sin(ang), r*np.cos(ang), 0., 1.])
        ipos = P @ pos
        ipos /= ipos[-1]
        pts2.append([ipos[0], ipos[1]])

    base_pos = np.array([0, 0, 0., 0.])
    for ang in np.linspace(0, 360, 500):
        ang = np.deg2rad(ang)
        pos = base_pos + np.array([r*np.sin(ang), r*np.cos(ang), 0., 1.])
        ipos = P @ pos
        ipos /= ipos[-1]
        pts3.append([ipos[0], ipos[1]])

    XEllipse1 = np.array(pts1, np.int32)
    XEllipse2 = np.array(pts2, np.int32)
    XEllipse3 = np.array(pts3, np.int32)
    frame = cv2.polylines(frame, [XEllipse1], False, (255, 0, 0), 3)
    frame = cv2.polylines(frame, [XEllipse2], False, (255, 0, 0), 3)
    frame = cv2.polylines(frame, [XEllipse3], False, (255, 0, 0), 3)

    return frame


def process_input(input_path, input_type, model_kp, model_line, kp_threshold, line_threshold, pnl_refine,
                  save_path, display):

    cap = cv2.VideoCapture(input_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))

    cam = FramebyFrameCalib(iwidth=frame_width, iheight=frame_height, denormalize=True)

    if input_type == 'video':
        cap = cv2.VideoCapture(input_path)
        if save_path != "":
            out = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))

        pbar = tqdm(total=total_frames)

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            final_params_dict = inference(cam, frame, model, model_l, kp_threshold, line_threshold, pnl_refine)
            if final_params_dict is not None:
                P = projection_from_cam_params(final_params_dict)
                projected_frame = project(frame, P)
            else:
                projected_frame = frame

            if save_path != "":
                out.write(projected_frame)

            if display:
                cv2.imshow('Projected Frame', projected_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            pbar.update(1)

        cap.release()
        if save_path != "":
            out.release()
        cv2.destroyAllWindows()

    elif input_type == 'image':
        frame = cv2.imread(input_path)
        if frame is None:
            print(f"Error: Unable to read the image {input_path}")
            return

        final_params_dict = inference(cam, frame, model, model_l, kp_threshold, line_threshold, pnl_refine)

        print("\n"*4,"final_params_dict: ", final_params_dict)

        if final_params_dict is not None:
            P = projection_from_cam_params(final_params_dict)
            projected_frame = project(frame, P)
        else:
            projected_frame = frame

        if save_path != "":
            cv2.imwrite(save_path, projected_frame)
        else:
            plt.imshow(cv2.cvtColor(projected_frame, cv2.COLOR_BGR2RGB))
            plt.axis('off')
            plt.show()

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Process video or image and plot lines on each frame.")
    parser.add_argument("--weights_kp", type=str, help="Path to the model for keypoint inference.")
    parser.add_argument("--weights_line", type=str, help="Path to the model for line projection.")
    parser.add_argument("--kp_threshold", type=float, default=0.3434, help="Threshold for keypoint detection.")
    parser.add_argument("--line_threshold", type=float, default=0.7867, help="Threshold for line detection.")
    parser.add_argument("--pnl_refine", action="store_true", help="Enable PnL refinement module.")
    parser.add_argument("--device", type=str, default="cuda:0", help="CPU or CUDA device index")
    parser.add_argument("--input_path", type=str, required=True, help="Path to the input video or image file.")
    parser.add_argument("--input_type", type=str, choices=['video', 'image'], required=True,
                        help="Type of input: 'video' or 'image'.")
    parser.add_argument("--save_path", type=str, default="", help="Path to save the processed video.")
    parser.add_argument("--display", action="store_true", help="Enable real-time display.")
    args = parser.parse_args()


    input_path = args.input_path
    input_type = args.input_type
    model_kp = args.weights_kp
    model_line = args.weights_line
    pnl_refine = args.pnl_refine
    save_path = args.save_path
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    display = args.display and input_type == 'video'
    kp_threshold = args.kp_threshold
    line_threshold = args.line_threshold

    cfg = yaml.safe_load(open("config/hrnetv2_w48.yaml", 'r'))
    cfg_l = yaml.safe_load(open("config/hrnetv2_w48_l.yaml", 'r'))

    loaded_state = torch.load(args.weights_kp, map_location=device, weights_only=True)
    model = get_cls_net(cfg)
    model.load_state_dict(loaded_state)
    model.to(device)
    model.eval()

    loaded_state_l = torch.load(args.weights_line, map_location=device, weights_only=True)
    model_l = get_cls_net_l(cfg_l)
    model_l.load_state_dict(loaded_state_l)
    model_l.to(device)
    model_l.eval()

    transform2 = T.Resize((540, 960))

    process_input(input_path, input_type, model_kp, model_line, kp_threshold, line_threshold, pnl_refine,
                  save_path, display)


# python inference.py --weights_kp "models/SV_FT_TSWC_kp" --weights_line "models/SV_FT_TSWC_lines" --input_path "examples/input/FootDrone.mp4" --input_type "video" --save_path "examples/output/video.mp4" --kp_threshold 0.15 --line_threshold 0.15 --pnl_refine
# python inference.py --weights_kp "SV_FT_TSWC_kp" --weights_line "SV_FT_TSWC_lines" --input_path "examples/input/FootDrone.jpg" --input_type "image" --save_path "examples/output/FootDrone_inf.jpg" --kp_threshold 0.15 --line_threshold 0.15
# python inference.py --weights_kp "models/SV_FT_TSWC_kp" --weights_line "models/SV_FT_TSWC_lines" --input_path "examples/input/fisheye_messi.png" --input_type "image" --save_path "examples/output/fisheye_messi_inf.png" --kp_threshold 0.15 --line_threshold 0.15
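Editor's note: projection_from_cam_params builds the pinhole projection matrix P = K @ R @ [I | -C] from the estimated intrinsics K, rotation R and camera position C, and project() shifts pitch coordinates by (52.5, 34) because the calibration uses a pitch-centred world frame. A minimal sketch of projecting a single pitch point with that matrix:

# Sketch: project one pitch point (metres) into the image with P from
# projection_from_cam_params(final_params_dict).
import numpy as np

def project_point(P, x_m, y_m, z_m=0.0):
    # Shift to the pitch-centred frame used by project() above, then apply P.
    p = P @ np.array([x_m - 105 / 2, y_m - 68 / 2, z_m, 1.0])
    return p[:2] / p[2]

# centre_spot_px = project_point(P, 52.5, 34.0)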
model/cls_hrnet.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import os
|
6 |
+
import logging
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
BatchNorm2d = nn.BatchNorm2d
|
14 |
+
BN_MOMENTUM = 0.1
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
19 |
+
"""3x3 convolution with padding"""
|
20 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
21 |
+
stride=stride, padding=1, bias=False)
|
22 |
+
|
23 |
+
|
24 |
+
class BasicBlock(nn.Module):
|
25 |
+
expansion = 1
|
26 |
+
|
27 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
28 |
+
super(BasicBlock, self).__init__()
|
29 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
30 |
+
self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
31 |
+
self.relu = nn.ReLU(inplace=True)
|
32 |
+
self.conv2 = conv3x3(planes, planes)
|
33 |
+
self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
34 |
+
self.downsample = downsample
|
35 |
+
self.stride = stride
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
residual = x
|
39 |
+
|
40 |
+
out = self.conv1(x)
|
41 |
+
out = self.bn1(out)
|
42 |
+
out = self.relu(out)
|
43 |
+
|
44 |
+
out = self.conv2(out)
|
45 |
+
out = self.bn2(out)
|
46 |
+
|
47 |
+
if self.downsample is not None:
|
48 |
+
residual = self.downsample(x)
|
49 |
+
|
50 |
+
out += residual
|
51 |
+
out = self.relu(out)
|
52 |
+
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
class Bottleneck(nn.Module):
|
57 |
+
expansion = 4
|
58 |
+
|
59 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
60 |
+
super(Bottleneck, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
62 |
+
self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
63 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
64 |
+
padding=1, bias=False)
|
65 |
+
self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
66 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
|
67 |
+
bias=False)
|
68 |
+
self.bn3 = BatchNorm2d(planes * self.expansion,
|
69 |
+
momentum=BN_MOMENTUM)
|
70 |
+
self.relu = nn.ReLU(inplace=True)
|
71 |
+
self.downsample = downsample
|
72 |
+
self.stride = stride
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
residual = x
|
76 |
+
|
77 |
+
out = self.conv1(x)
|
78 |
+
out = self.bn1(out)
|
79 |
+
out = self.relu(out)
|
80 |
+
|
81 |
+
out = self.conv2(out)
|
82 |
+
out = self.bn2(out)
|
83 |
+
out = self.relu(out)
|
84 |
+
|
85 |
+
out = self.conv3(out)
|
86 |
+
out = self.bn3(out)
|
87 |
+
|
88 |
+
if self.downsample is not None:
|
89 |
+
residual = self.downsample(x)
|
90 |
+
|
91 |
+
out += residual
|
92 |
+
out = self.relu(out)
|
93 |
+
|
94 |
+
return out
|
95 |
+
|
96 |
+
|
97 |
+
class HighResolutionModule(nn.Module):
|
98 |
+
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
99 |
+
num_channels, fuse_method, multi_scale_output=True):
|
100 |
+
super(HighResolutionModule, self).__init__()
|
101 |
+
self._check_branches(
|
102 |
+
num_branches, blocks, num_blocks, num_inchannels, num_channels)
|
103 |
+
|
104 |
+
self.num_inchannels = num_inchannels
|
105 |
+
self.fuse_method = fuse_method
|
106 |
+
self.num_branches = num_branches
|
107 |
+
|
108 |
+
self.multi_scale_output = multi_scale_output
|
109 |
+
|
110 |
+
self.branches = self._make_branches(
|
111 |
+
num_branches, blocks, num_blocks, num_channels)
|
112 |
+
self.fuse_layers = self._make_fuse_layers()
|
113 |
+
self.relu = nn.ReLU(inplace=True)
|
114 |
+
|
115 |
+
def _check_branches(self, num_branches, blocks, num_blocks,
|
116 |
+
num_inchannels, num_channels):
|
117 |
+
if num_branches != len(num_blocks):
|
118 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
119 |
+
num_branches, len(num_blocks))
|
120 |
+
logger.error(error_msg)
|
121 |
+
raise ValueError(error_msg)
|
122 |
+
|
123 |
+
if num_branches != len(num_channels):
|
124 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
125 |
+
num_branches, len(num_channels))
|
126 |
+
logger.error(error_msg)
|
127 |
+
raise ValueError(error_msg)
|
128 |
+
|
129 |
+
if num_branches != len(num_inchannels):
|
130 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
131 |
+
num_branches, len(num_inchannels))
|
132 |
+
logger.error(error_msg)
|
133 |
+
raise ValueError(error_msg)
|
134 |
+
|
135 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
136 |
+
stride=1):
|
137 |
+
downsample = None
|
138 |
+
if stride != 1 or \
|
139 |
+
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
140 |
+
downsample = nn.Sequential(
|
141 |
+
nn.Conv2d(self.num_inchannels[branch_index],
|
142 |
+
num_channels[branch_index] * block.expansion,
|
143 |
+
kernel_size=1, stride=stride, bias=False),
|
144 |
+
BatchNorm2d(num_channels[branch_index] * block.expansion,
|
145 |
+
momentum=BN_MOMENTUM),
|
146 |
+
)
|
147 |
+
|
148 |
+
layers = []
|
149 |
+
layers.append(block(self.num_inchannels[branch_index],
|
150 |
+
num_channels[branch_index], stride, downsample))
|
151 |
+
self.num_inchannels[branch_index] = \
|
152 |
+
num_channels[branch_index] * block.expansion
|
153 |
+
for i in range(1, num_blocks[branch_index]):
|
154 |
+
layers.append(block(self.num_inchannels[branch_index],
|
155 |
+
num_channels[branch_index]))
|
156 |
+
|
157 |
+
return nn.Sequential(*layers)
|
158 |
+
|
159 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
160 |
+
branches = []
|
161 |
+
|
162 |
+
for i in range(num_branches):
|
163 |
+
branches.append(
|
164 |
+
self._make_one_branch(i, block, num_blocks, num_channels))
|
165 |
+
|
166 |
+
return nn.ModuleList(branches)
|
167 |
+
|
168 |
+
def _make_fuse_layers(self):
|
169 |
+
if self.num_branches == 1:
|
170 |
+
return None
|
171 |
+
|
172 |
+
num_branches = self.num_branches
|
173 |
+
num_inchannels = self.num_inchannels
|
174 |
+
fuse_layers = []
|
175 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
176 |
+
fuse_layer = []
|
177 |
+
for j in range(num_branches):
|
178 |
+
if j > i:
|
179 |
+
fuse_layer.append(nn.Sequential(
|
180 |
+
nn.Conv2d(num_inchannels[j],
|
181 |
+
num_inchannels[i],
|
182 |
+
1,
|
183 |
+
1,
|
184 |
+
0,
|
185 |
+
bias=False),
|
186 |
+
BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
|
187 |
+
# nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
|
188 |
+
elif j == i:
|
189 |
+
fuse_layer.append(None)
|
190 |
+
else:
|
191 |
+
conv3x3s = []
|
192 |
+
for k in range(i - j):
|
193 |
+
if k == i - j - 1:
|
194 |
+
num_outchannels_conv3x3 = num_inchannels[i]
|
195 |
+
conv3x3s.append(nn.Sequential(
|
196 |
+
nn.Conv2d(num_inchannels[j],
|
197 |
+
num_outchannels_conv3x3,
|
198 |
+
3, 2, 1, bias=False),
|
199 |
+
BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
|
200 |
+
else:
|
201 |
+
num_outchannels_conv3x3 = num_inchannels[j]
|
202 |
+
conv3x3s.append(nn.Sequential(
|
203 |
+
nn.Conv2d(num_inchannels[j],
|
204 |
+
num_outchannels_conv3x3,
|
205 |
+
3, 2, 1, bias=False),
|
206 |
+
BatchNorm2d(num_outchannels_conv3x3,
|
207 |
+
momentum=BN_MOMENTUM),
|
208 |
+
nn.ReLU(inplace=True)))
|
209 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
210 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
211 |
+
|
212 |
+
return nn.ModuleList(fuse_layers)
|
213 |
+
|
214 |
+
def get_num_inchannels(self):
|
215 |
+
return self.num_inchannels
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
if self.num_branches == 1:
|
219 |
+
return [self.branches[0](x[0])]
|
220 |
+
|
221 |
+
for i in range(self.num_branches):
|
222 |
+
x[i] = self.branches[i](x[i])
|
223 |
+
|
224 |
+
x_fuse = []
|
225 |
+
for i in range(len(self.fuse_layers)):
|
226 |
+
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
227 |
+
for j in range(1, self.num_branches):
|
228 |
+
if i == j:
|
229 |
+
y = y + x[j]
|
230 |
+
elif j > i:
|
231 |
+
y = y + F.interpolate(
|
232 |
+
self.fuse_layers[i][j](x[j]),
|
233 |
+
size=[x[i].shape[2], x[i].shape[3]],
|
234 |
+
mode='bilinear')
|
235 |
+
else:
|
236 |
+
y = y + self.fuse_layers[i][j](x[j])
|
237 |
+
x_fuse.append(self.relu(y))
|
238 |
+
|
239 |
+
return x_fuse
|
240 |
+
|
241 |
+
|
242 |
+
blocks_dict = {
|
243 |
+
'BASIC': BasicBlock,
|
244 |
+
'BOTTLENECK': Bottleneck
|
245 |
+
}
|
246 |
+
|
247 |
+
|
248 |
+
class HighResolutionNet(nn.Module):
|
249 |
+
|
250 |
+
def __init__(self, config, **kwargs):
|
251 |
+
self.inplanes = 64
|
252 |
+
extra = config['MODEL']['EXTRA']
|
253 |
+
super(HighResolutionNet, self).__init__()
|
254 |
+
|
255 |
+
# stem net
|
256 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1,
|
257 |
+
bias=False)
|
258 |
+
self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
|
259 |
+
self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1,
|
260 |
+
bias=False)
|
261 |
+
self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
|
262 |
+
self.relu = nn.ReLU(inplace=True)
|
263 |
+
self.sf = nn.Softmax(dim=1)
|
264 |
+
self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
|
265 |
+
|
266 |
+
self.stage2_cfg = extra['STAGE2']
|
267 |
+
num_channels = self.stage2_cfg['NUM_CHANNELS']
|
268 |
+
block = blocks_dict[self.stage2_cfg['BLOCK']]
|
269 |
+
num_channels = [
|
270 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))]
|
271 |
+
self.transition1 = self._make_transition_layer(
|
272 |
+
[256], num_channels)
|
273 |
+
self.stage2, pre_stage_channels = self._make_stage(
|
274 |
+
self.stage2_cfg, num_channels)
|
275 |
+
|
276 |
+
self.stage3_cfg = extra['STAGE3']
|
277 |
+
num_channels = self.stage3_cfg['NUM_CHANNELS']
|
278 |
+
block = blocks_dict[self.stage3_cfg['BLOCK']]
|
279 |
+
num_channels = [
|
280 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))]
|
281 |
+
self.transition2 = self._make_transition_layer(
|
282 |
+
pre_stage_channels, num_channels)
|
283 |
+
self.stage3, pre_stage_channels = self._make_stage(
|
284 |
+
self.stage3_cfg, num_channels)
|
285 |
+
|
286 |
+
self.stage4_cfg = extra['STAGE4']
|
287 |
+
num_channels = self.stage4_cfg['NUM_CHANNELS']
|
288 |
+
block = blocks_dict[self.stage4_cfg['BLOCK']]
|
289 |
+
num_channels = [
|
290 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))]
|
291 |
+
self.transition3 = self._make_transition_layer(
|
292 |
+
pre_stage_channels, num_channels)
|
293 |
+
self.stage4, pre_stage_channels = self._make_stage(
|
294 |
+
self.stage4_cfg, num_channels, multi_scale_output=True)
|
295 |
+
|
296 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
297 |
+
final_inp_channels = sum(pre_stage_channels) + self.inplanes
|
298 |
+
|
299 |
+
self.head = nn.Sequential(nn.Sequential(
|
300 |
+
nn.Conv2d(
|
301 |
+
in_channels=final_inp_channels,
|
302 |
+
out_channels=final_inp_channels,
|
303 |
+
kernel_size=1),
|
304 |
+
BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
|
305 |
+
nn.ReLU(inplace=True),
|
306 |
+
nn.Conv2d(
|
307 |
+
in_channels=final_inp_channels,
|
308 |
+
out_channels=config['MODEL']['NUM_JOINTS'],
|
309 |
+
kernel_size=extra['FINAL_CONV_KERNEL']),
|
310 |
+
nn.Softmax(dim=1)))
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
def _make_head(self, x, x_skip):
|
315 |
+
x = self.upsample(x)
|
316 |
+
x = torch.cat([x, x_skip], dim=1)
|
317 |
+
x = self.head(x)
|
318 |
+
|
319 |
+
return x
|
320 |
+
|
321 |
+
def _make_transition_layer(
|
322 |
+
self, num_channels_pre_layer, num_channels_cur_layer):
|
323 |
+
num_branches_cur = len(num_channels_cur_layer)
|
324 |
+
num_branches_pre = len(num_channels_pre_layer)
|
325 |
+
|
326 |
+
transition_layers = []
|
327 |
+
for i in range(num_branches_cur):
|
328 |
+
if i < num_branches_pre:
|
329 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
330 |
+
transition_layers.append(nn.Sequential(
|
331 |
+
nn.Conv2d(num_channels_pre_layer[i],
|
332 |
+
num_channels_cur_layer[i],
|
333 |
+
3,
|
334 |
+
1,
|
335 |
+
1,
|
336 |
+
bias=False),
|
337 |
+
BatchNorm2d(
|
338 |
+
num_channels_cur_layer[i], momentum=BN_MOMENTUM),
|
339 |
+
nn.ReLU(inplace=True)))
|
340 |
+
else:
|
341 |
+
transition_layers.append(None)
|
342 |
+
else:
|
343 |
+
conv3x3s = []
|
344 |
+
for j in range(i + 1 - num_branches_pre):
|
345 |
+
inchannels = num_channels_pre_layer[-1]
|
346 |
+
outchannels = num_channels_cur_layer[i] \
|
347 |
+
if j == i - num_branches_pre else inchannels
|
348 |
+
conv3x3s.append(nn.Sequential(
|
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used in the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        # h, w = x.size(2), x.size(3)
        x = self.conv1(x)
        x_skip = x.clone()
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Head Part
        height, width = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
        x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
        x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
        x = torch.cat([x[0], x1, x2, x3], 1)
        x = self._make_head(x, x_skip)

        return x

    def init_weights(self, pretrained=''):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                # nn.init.normal_(m.weight, std=0.001)
                # nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if pretrained != '':
            if os.path.isfile(pretrained):
                pretrained_dict = torch.load(pretrained)
                logger.info('=> loading pretrained model {}'.format(pretrained))
                print('=> loading pretrained model {}'.format(pretrained))
                model_dict = self.state_dict()
                pretrained_dict = {k: v for k, v in pretrained_dict.items()
                                   if k in model_dict.keys()}
                for k, _ in pretrained_dict.items():
                    logger.info(
                        '=> loading {} pretrained model {}'.format(k, pretrained))
                    # print('=> loading {} pretrained model {}'.format(k, pretrained))
                model_dict.update(pretrained_dict)
                self.load_state_dict(model_dict)
            else:
                sys.exit(f'Weights {pretrained} not found.')


def get_cls_net(config, pretrained='', **kwargs):
    model = HighResolutionNet(config, **kwargs)
    model.init_weights(pretrained)
    return model
model/cls_hrnet_l.py
ADDED
@@ -0,0 +1,478 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F


BatchNorm2d = nn.BatchNorm2d
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion,
                               momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(num_channels[branch_index] * block.expansion,
                            momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                    # nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[x[i].shape[2], x[i].shape[3]],
                        mode='bilinear')
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}


class HighResolutionNet(nn.Module):

    def __init__(self, config, **kwargs):
        self.inplanes = 64
        extra = config['MODEL']['EXTRA']
        super(HighResolutionNet, self).__init__()

        # stem net
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.sf = nn.Softmax(dim=1)
        self.layer1 = self._make_layer(Bottleneck, self.inplanes, self.inplanes, 4)

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        final_inp_channels = sum(pre_stage_channels) + self.inplanes

        self.head = nn.Sequential(nn.Sequential(
            nn.Conv2d(
                in_channels=final_inp_channels,
                out_channels=final_inp_channels,
                kernel_size=1),
            BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=final_inp_channels,
                out_channels=config['MODEL']['NUM_JOINTS'],
                kernel_size=extra['FINAL_CONV_KERNEL']),
            nn.Sigmoid()))

    def _make_head(self, x, x_skip):
        x = self.upsample(x)
        x = torch.cat([x, x_skip], dim=1)
        x = self.head(x)

        return x

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used in the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        # h, w = x.size(2), x.size(3)
        x = self.conv1(x)
        x_skip = x.clone()
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Head Part
        height, width = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
        x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
        x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
        x = torch.cat([x[0], x1, x2, x3], 1)
        x = self._make_head(x, x_skip)

        return x

    def init_weights(self, pretrained=''):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.normal_(m.weight, std=0.001)
                # nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))
            print('=> loading pretrained model {}'.format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict.keys()}
            for k, _ in pretrained_dict.items():
                logger.info(
                    '=> loading {} pretrained model {}'.format(k, pretrained))
                # print('=> loading {} pretrained model {}'.format(k, pretrained))
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)


def get_cls_net(config, pretrained='', **kwargs):
    model = HighResolutionNet(config, **kwargs)
    model.init_weights(pretrained)
    return model
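This file mirrors model/cls_hrnet.py almost line for line; the practical differences are the weight initialisation (normal_ with std=0.001 instead of Kaiming) and an init_weights that silently skips loading when the checkpoint path does not exist rather than exiting. A hedged sketch of instantiating the keypoint and line backbones side by side; the exact wiring used by the repo lives in its inference scripts, and the config paths, checkpoint name and device choice below are assumptions:

    import yaml
    import torch
    from model.cls_hrnet import get_cls_net
    from model.cls_hrnet_l import get_cls_net as get_cls_net_l

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cfg_kp = yaml.safe_load(open('config/hrnetv2_w48.yaml'))    # keypoint config (assumed path)
    cfg_ln = yaml.safe_load(open('config/hrnetv2_w48_l.yaml'))  # line config (assumed path)

    model_kp = get_cls_net(cfg_kp).to(device).eval()    # keypoint heatmaps
    model_ln = get_cls_net_l(cfg_ln).to(device).eval()  # line-extremity heatmaps
    # Trained weights would typically be restored afterwards, e.g.
    # model_kp.load_state_dict(torch.load('weights_kp.pth', map_location=device))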
model/dataloader.py
ADDED
@@ -0,0 +1,219 @@
import copy
import os
import sys
import glob
import json
import torch
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as f

from torchvision.transforms import v2
from torch.utils.data import Dataset
from PIL import Image

from utils.utils_keypoints import KeypointsDB
from utils.utils_keypointsWC import KeypointsWCDB


class SoccerNetCalibrationDataset(Dataset):

    def __init__(self, root_dir, split, transform, main_cam_only=True):

        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        self.match_info = json.load(open(root_dir + split + '/match_info.json'))
        self.files = self.get_image_files(rate=1)

        if main_cam_only:
            self.get_main_camera()

    def get_image_files(self, rate=3):
        files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
        files.sort()
        if rate > 1:
            files = files[::rate]
        return files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.files[idx]
        image = Image.open(img_name)
        data = json.load(open(img_name.split('.')[0] + ".json"))
        data = self.correct_labels(data)
        sample = self.transform({'image': image, 'data': data})

        img_db = KeypointsDB(sample['data'], sample['image'])
        target, mask = img_db.get_tensor_w_mask()
        image = sample['image']

        return image, torch.from_numpy(target).float(), torch.from_numpy(mask).float()

    def get_main_camera(self):
        self.files = [file for file in self.files if int(self.match_info[file.split('/')[-1]]['ms_time']) == \
                      int(self.match_info[file.split('/')[-1]]['replay_time'])]

    def correct_labels(self, data):
        if 'Goal left post left' in data.keys():
            data['Goal left post left '] = copy.deepcopy(data['Goal left post left'])
            del data['Goal left post left']

        return data


class WorldCup2014Dataset(Dataset):

    def __init__(self, root_dir, split, transform):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        assert self.split in ['train_val', 'test'], f'unknown dataset type {self.split}'

        self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
        self.homographies = glob.glob(os.path.join(self.root_dir + self.split, "*.homographyMatrix"))
        self.num_samples = len(self.files)

        self.files.sort()
        self.homographies.sort()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.get_image_by_index(idx)
        homography = self.get_homography_by_index(idx)
        img_db = KeypointsWCDB(image, homography, (960, 540))
        target, mask = img_db.get_tensor_w_mask()

        sample = self.transform({'image': image, 'target': target, 'mask': mask})

        return sample['image'], sample['target'], sample['mask']

    def convert_homography_WC14GT_to_SN(self, H):
        T = np.eye(3)
        #T[0, -1] = -115 / 2
        #T[1, -1] = -74 / 2
        yard2meter = 0.9144
        S = np.eye(3)
        S[0, 0] = yard2meter
        S[1, 1] = yard2meter
        H_SN = S @ (T @ H)

        return H_SN

    def get_image_by_index(self, index):
        img_file = self.files[index]
        image = Image.open(img_file)
        return image

    def get_homography_by_index(self, index):
        homography_file = self.homographies[index]
        with open(homography_file, 'r') as file:
            lines = file.readlines()
            matrix_elements = []
            for line in lines:
                matrix_elements.extend([float(element) for element in line.split()])
        homography = np.array(matrix_elements).reshape((3, 3))
        homography = self.convert_homography_WC14GT_to_SN(homography)
        homography = torch.from_numpy(homography)
        homography = homography / homography[2:3, 2:3]
        return homography


class TSWorldCupDataset(Dataset):

    def __init__(self, root_dir, split, transform):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        assert self.split in ['train', 'test'], f'unknown dataset type {self.split}'

        self.files_txt = self.get_txt()

        self.files = self.get_jpg_files()
        self.homographies = self.get_homographies()
        self.num_samples = len(self.files)

        self.files.sort()
        self.homographies.sort()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.get_image_by_index(idx)
        homography = self.get_homography_by_index(idx)
        img_db = KeypointsWCDB(image, homography, (960, 540))
        target, mask = img_db.get_tensor_w_mask()

        sample = self.transform({'image': image, 'target': target, 'mask': mask})

        return sample['image'], sample['target'], sample['mask']

    def get_txt(self):
        with open(self.root_dir + self.split + '.txt', 'r') as file:
            lines = file.readlines()
            lines = [line.strip() for line in lines]
        return lines

    def get_jpg_files(self):
        all_jpg_files = []
        for dir in self.files_txt:
            full_dir = self.root_dir + "Dataset/80_95/" + dir
            jpg_files = []
            for file in os.listdir(full_dir):
                if file.lower().endswith('.jpg') or file.lower().endswith('.jpeg'):
                    jpg_files.append(os.path.join(full_dir, file))

            all_jpg_files.extend(jpg_files)

        return all_jpg_files

    def get_homographies(self):
        all_homographies = []
        for dir in self.files_txt:
            full_dir = self.root_dir + "Annotations/80_95/" + dir
            homographies = []
            for file in os.listdir(full_dir):
                if file.lower().endswith('.npy'):
                    homographies.append(os.path.join(full_dir, file))

            all_homographies.extend(homographies)

        return all_homographies

    def convert_homography_WC14GT_to_SN(self, H):
        T = np.eye(3)
        #T[0, -1] = -115 / 2
        #T[1, -1] = -74 / 2
        yard2meter = 0.9144
        S = np.eye(3)
        S[0, 0] = yard2meter
        S[1, 1] = yard2meter
        H_SN = S @ (T @ H)

        return H_SN

    def get_image_by_index(self, index):
        img_file = self.files[index]
        image = Image.open(img_file)
        return image

    def get_homography_by_index(self, index):
        homography_file = self.homographies[index]
        homography = np.load(homography_file)
        homography = self.convert_homography_WC14GT_to_SN(homography)
        homography = torch.from_numpy(homography)
        homography = homography / homography[2:3, 2:3]
        return homography
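A hedged sketch of plugging SoccerNetCalibrationDataset into a DataLoader. The dataset assumes <root>/<split>/ holds the *.jpg frames, their per-image *.json annotations and match_info.json; the root path below and the identity-style transform are illustrative stand-ins for the real pipelines in model/transforms.py (the transform must accept and return the {'image', 'data'} dict, and the image should come out as a tensor for default collation):

    import torchvision.transforms.functional as f
    from torch.utils.data import DataLoader
    from model.dataloader import SoccerNetCalibrationDataset

    def to_tensor_transform(sample):
        # minimal stand-in: convert the PIL image, pass the labels through untouched
        sample['image'] = f.to_tensor(sample['image'])
        return sample

    dataset = SoccerNetCalibrationDataset(root_dir='datasets/calibration/',  # assumed layout, note trailing slash
                                          split='train',
                                          transform=to_tensor_transform,
                                          main_cam_only=True)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
    images, targets, masks = next(iter(loader))  # images, keypoint heatmaps, per-channel masks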
model/dataloader_l.py
ADDED
@@ -0,0 +1,202 @@
import os
import sys
import glob
import json
import torch
import numpy as np

from torchvision.transforms import v2
from torch.utils.data import Dataset
from PIL import Image

from utils.utils_lines import LineKeypointsDB
from utils.utils_linesWC import LineKeypointsWCDB


class SoccerNetCalibrationDataset(Dataset):

    def __init__(self, root_dir, split, transform, main_cam_only=True):

        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        #self.match_info = json.load(open(root_dir + split + '/match_info.json'))
        self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))

        if main_cam_only:
            self.get_main_camera()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.files[idx]
        image = Image.open(img_name)
        data = json.load(open(img_name.split('.')[0] + ".json"))
        sample = self.transform({'image': image, 'data': data})

        img_db = LineKeypointsDB(sample['data'], sample['image'])
        target = img_db.get_tensor()

        return sample['image'], torch.from_numpy(target).float()

    def get_main_camera(self):
        self.files = [file for file in self.files if int(self.match_info[file.split('/')[-1]]['ms_time']) == \
                      int(self.match_info[file.split('/')[-1]]['replay_time'])]


class WorldCup2014Dataset(Dataset):

    def __init__(self, root_dir, split, transform):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        assert self.split in ['train_val', 'test'], f'unknown dataset type {self.split}'

        self.files = glob.glob(os.path.join(self.root_dir + self.split, "*.jpg"))
        self.homographies = glob.glob(os.path.join(self.root_dir + self.split, "*.homographyMatrix"))
        self.num_samples = len(self.files)

        self.files.sort()
        self.homographies.sort()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.get_image_by_index(idx)
        homography = self.get_homography_by_index(idx)
        img_db = LineKeypointsWCDB(image, homography, (960, 540))
        target, mask = img_db.get_tensor_w_mask()

        sample = self.transform({'image': image, 'target': target, 'mask': mask})

        return sample['image'], sample['target'], sample['mask']

    def convert_homography_WC14GT_to_SN(self, H):
        T = np.eye(3)
        #T[0, -1] = -115 / 2
        #T[1, -1] = -74 / 2
        yard2meter = 0.9144
        S = np.eye(3)
        S[0, 0] = yard2meter
        S[1, 1] = yard2meter
        H_SN = S @ (T @ H)

        return H_SN

    def get_image_by_index(self, index):
        img_file = self.files[index]
        image = Image.open(img_file)
        return image

    def get_homography_by_index(self, index):
        homography_file = self.homographies[index]
        with open(homography_file, 'r') as file:
            lines = file.readlines()
            matrix_elements = []
            for line in lines:
                matrix_elements.extend([float(element) for element in line.split()])
        homography = np.array(matrix_elements).reshape((3, 3))
        homography = self.convert_homography_WC14GT_to_SN(homography)
        homography = torch.from_numpy(homography)
        homography = homography / homography[2:3, 2:3]
        return homography


class TSWorldCupDataset(Dataset):

    def __init__(self, root_dir, split, transform):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        assert self.split in ['train', 'test'], f'unknown dataset type {self.split}'

        self.files_txt = self.get_txt()

        self.files = self.get_jpg_files()
        self.homographies = self.get_homographies()
        self.num_samples = len(self.files)

        self.files.sort()
        self.homographies.sort()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.get_image_by_index(idx)
        homography = self.get_homography_by_index(idx)
        img_db = LineKeypointsWCDB(image, homography, (960, 540))
        target, mask = img_db.get_tensor_w_mask()

        sample = self.transform({'image': image, 'target': target, 'mask': mask})

        return sample['image'], sample['target'], sample['mask']

    def get_txt(self):
        with open(self.root_dir + self.split + '.txt', 'r') as file:
            lines = file.readlines()
            lines = [line.strip() for line in lines]
        return lines

    def get_jpg_files(self):
        all_jpg_files = []
        for dir in self.files_txt:
            full_dir = self.root_dir + "Dataset/80_95/" + dir
            jpg_files = []
            for file in os.listdir(full_dir):
                if file.lower().endswith('.jpg') or file.lower().endswith('.jpeg'):
                    jpg_files.append(os.path.join(full_dir, file))

            all_jpg_files.extend(jpg_files)

        return all_jpg_files

    def get_homographies(self):
        all_homographies = []
        for dir in self.files_txt:
            full_dir = self.root_dir + "Annotations/80_95/" + dir
            homographies = []
            for file in os.listdir(full_dir):
                if file.lower().endswith('.npy'):
                    homographies.append(os.path.join(full_dir, file))

            all_homographies.extend(homographies)

        return all_homographies

    def convert_homography_WC14GT_to_SN(self, H):
        T = np.eye(3)
        #T[0, -1] = -115 / 2
        #T[1, -1] = -74 / 2
        yard2meter = 0.9144
        S = np.eye(3)
        S[0, 0] = yard2meter
        S[1, 1] = yard2meter
        H_SN = S @ (T @ H)

        return H_SN

    def get_image_by_index(self, index):
        img_file = self.files[index]
        image = Image.open(img_file)
        return image

    def get_homography_by_index(self, index):
        homography_file = self.homographies[index]
        homography = np.load(homography_file)
        homography = self.convert_homography_WC14GT_to_SN(homography)
        homography = torch.from_numpy(homography)
        homography = homography / homography[2:3, 2:3]
        return homography
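Both dataloader files rescale the WC14/TS-WorldCup ground-truth homographies from the 115 x 74 yard template into metres via convert_homography_WC14GT_to_SN. A small sanity-check sketch of that composition, with the homography replaced by the identity so only the S @ T part acts (the numbers are just the yard-to-metre conversion):

    import numpy as np

    yard2meter = 0.9144
    T = np.eye(3)                              # the translation terms are commented out in the code above
    S = np.diag([yard2meter, yard2meter, 1.0])
    H_SN = S @ (T @ np.eye(3))                 # stand-in for S @ (T @ H)

    far_corner = np.array([115.0, 74.0, 1.0])  # far corner of the yard-based template
    print(H_SN @ far_corner)                   # -> approx. [105.156, 67.666, 1.0] (metres)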
model/losses.py
ADDED
@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class MSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='none')

    def forward(self, output, target, mask=None):
        loss = self.criterion(output, target)
        if mask is not None:
            loss = (loss * mask).mean()
        else:
            loss = (loss).mean()
        return loss

class KLDivLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='batchmean')

    def forward(self, output, target, mask=None):
        if mask is not None:
            output_masked = output * mask
            target_masked = target * mask
            loss = self.criterion(F.log_softmax(output_masked), target_masked)
        else:
            loss = self.criterion(F.log_softmax(output), target)
        return loss

class HeatmapWeightingMSELoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='none')

    def forward(self, output, target, mask=None):
        """Forward function."""
        batch_size = output.size(0)
        num_joints = output.size(1)

        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = 0.

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            """
            Set different weight generation functions.
            weight = heatmap_gt + 1
            weight = heatmap_gt * 2 + 1
            weight = heatmap_gt * heatmap_gt + 1
            weight = torch.exp(heatmap_gt + 1)
            """

            if mask is not None:
                #weight = heatmap_gt * mask[:, idx] + 1
                weight = torch.exp(heatmap_gt * mask[:, idx] + 1)
                loss += torch.mean(self.criterion(heatmap_pred * mask[:, idx],
                                                  heatmap_gt * mask[:, idx]) * weight)
            else:
                weight = heatmap_gt + 1
                loss += torch.mean(self.criterion(heatmap_pred, heatmap_gt) * weight)
        return loss / (num_joints+1)


class CombMSEAW(nn.Module):
    def __init__(self, lambda1=1, lambda2=1, alpha=2.1, omega=14, epsilon=1, theta=0.5):
        super().__init__()
        # Adaptive wing loss
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.criterion1 = nn.MSELoss(reduction='none')
        self.alpha = alpha
        self.omega = omega
        self.epsilon = epsilon
        self.theta = theta

    def forward(self, pred, target, mask=None):
        loss = 0
        if mask is not None:
            pred_masked, target_masked = pred * mask, target * mask
            loss += self.lambda1 * self.criterion1(pred_masked, target_masked)
            loss += self.lambda2 * self.adaptive_wing(pred_masked, target_masked)
        else:
            loss += self.lambda1 * self.criterion1(pred, target)
            loss += self.lambda2 * self.adaptive_wing(pred, target)
        return torch.mean(loss)

    def adaptive_wing(self, pred, target):
        delta = (target - pred).abs()
        alpha_t = self.alpha - target
        A = self.omega * (
            1 / (1 + torch.pow(self.theta / self.epsilon,
                               alpha_t))) * alpha_t \
            * (torch.pow(self.theta / self.epsilon,
                         self.alpha - target - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, alpha_t))

        losses = torch.where(delta < self.theta,
                             self.omega * torch.log(
                                 1 + torch.pow(delta / self.epsilon, alpha_t)),
                             A * delta - C)
        return losses


class AdaptiveWingLoss(nn.Module):
    def __init__(self, alpha=2.1, omega=14, epsilon=1, theta=0.5):
        super().__init__()
        # Adaptive wing loss
        self.alpha = alpha
        self.omega = omega
        self.epsilon = epsilon
        self.theta = theta

    def forward(self, pred, target, mask=None):
        if mask is not None:
            pred_masked, target_masked = pred * mask, target * mask
            loss = self.adaptive_wing(pred_masked, target_masked)
        else:
            loss = self.adaptive_wing(pred, target)
        return loss

    def adaptive_wing(self, pred, target):
        delta = (target - pred).abs()
        alpha_t = self.alpha - target
        A = self.omega * (
            1 / (1 + torch.pow(self.theta / self.epsilon,
                               alpha_t))) * alpha_t \
            * (torch.pow(self.theta / self.epsilon,
                         self.alpha - target - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, alpha_t))

        losses = torch.where(delta < self.theta,
                             self.omega * torch.log(
                                 1 + torch.pow(delta / self.epsilon, alpha_t)),
                             A * delta - C)
        return torch.mean(losses)

class GaussianFocalLoss(nn.Module):
    """GaussianFocalLoss is a variant of focal loss.
    More details can be found in the `paper
    <https://arxiv.org/abs/1808.01244>`_
    Code is modified from `kp_utils.py
    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_  # noqa: E501
    Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
    not 0/1 binary target.
    Args:
        alpha (float): Power of prediction.
        gamma (float): Power of target for negative samples.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self,
                 alpha=2.0,
                 gamma=4.0,
                 reduction='mean',
                 loss_weight=1.0):
        super(GaussianFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                mask=None,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction
                in gaussian distribution.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override if reduction_override else self.reduction)
        if mask is not None:
            pred_masked, target_masked = pred * mask, target * mask
            loss_reg = self.loss_weight * self.gaussian_focal_loss(pred_masked, target_masked, alpha=self.alpha,
                                                                   gamma=self.gamma)
        else:
            loss_reg = self.loss_weight * self.gaussian_focal_loss(pred, target, alpha=self.alpha, gamma=self.gamma)
        return loss_reg.mean()

    def gaussian_focal_loss(self, pred, gaussian_target, alpha=2.0, gamma=4.0):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
        distribution.
        Args:
            pred (torch.Tensor): The prediction.
            gaussian_target (torch.Tensor): The learning target of the prediction
                in gaussian distribution.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 2.0.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 4.0.
        """
        eps = 1e-12
        pos_weights = gaussian_target.eq(1)
        neg_weights = (1 - gaussian_target).pow(gamma)
        pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
        neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
        return pos_loss + neg_loss
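A quick, self-contained sanity check of the masked losses defined above; the tensor shapes are illustrative only (any (B, K, H, W) heatmap pair works, and the mask just has to broadcast against it):

    import torch
    from model.losses import AdaptiveWingLoss, MSELoss

    pred = torch.rand(2, 58, 135, 240)    # predicted heatmaps
    target = torch.rand(2, 58, 135, 240)  # ground-truth heatmaps
    mask = torch.ones(2, 58, 1, 1)        # per-channel validity mask, broadcast over H x W

    print(MSELoss()(pred, target, mask).item())
    print(AdaptiveWingLoss()(pred, target, mask).item())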
model/metrics.py
ADDED
@@ -0,0 +1,312 @@
import cv2
import torch
import scipy
import random
import numpy as np

from shapely.geometry import Point, Polygon, MultiPoint

def calculate_metrics(gt, pred, mask, conf_th=0.1, dist_th=5):
    geometry_mask = (mask[:, :-1] > 0).cpu()

    pred_mask = torch.all((pred[:, :, :, -1] > conf_th), dim=-1)
    gt_mask = torch.all((gt[:, :, :, -1] > conf_th), dim=-1)

    pred_pos = pred[geometry_mask][:, 0, :]
    pred_mask = pred_mask[geometry_mask]
    gt_pos = gt[geometry_mask][:, 0, :]
    gt_mask = gt_mask[geometry_mask]

    distances = torch.norm(pred_pos - gt_pos, dim=1)

    # Count true positives, false positives, and false negatives based on distance threshold
    true_positives = ((distances < dist_th) & pred_mask & gt_mask).sum().item()
    true_negatives = (~pred_mask & ~gt_mask).sum().item()
    false_positives = ((pred_mask & ~gt_mask) | ((distances >= dist_th) & pred_mask & gt_mask)).sum().item()
    false_negatives = (~pred_mask & gt_mask).sum().item()

    # Calculate precision, recall, and F1 score
    accuracy = (true_positives + true_negatives) / geometry_mask.sum().item()
    precision = true_positives / (true_positives + false_positives + 1e-10)
    recall = true_positives / (true_positives + false_negatives + 1e-10)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    return accuracy, precision, recall, f1


def calculate_metrics_l(gt, pred, conf_th=0.1, dist_th=5):

    pred_pos = pred[:, :, :, :-1]
    gt_pos = gt[:, :, :, :-1]

    pred_mask = torch.all((pred[:, :, :, -1] > conf_th), dim=-1)
    gt_mask = torch.all((gt[:, :, :, -1] > conf_th), dim=-1)

    gt_flip = torch.flip(gt_pos, dims=[2])

    distances1 = torch.norm(pred_pos - gt_pos, dim=-1)
    distances2 = torch.norm(pred_pos - gt_flip, dim=-1)

    distances1_bool = torch.all((distances1 < dist_th), dim=-1)
    distances2_bool = torch.all((distances2 < dist_th), dim=-1)

    # Count true positives, false positives, and false negatives based on distance threshold
    true_positives = ((distances1_bool | distances2_bool) & pred_mask & gt_mask).sum().item()
    true_negatives = (~pred_mask & ~gt_mask).sum().item()
    false_positives = (
        (pred_mask & ~gt_mask) | ((~distances1_bool & ~distances2_bool) & pred_mask & gt_mask)).sum().item()
    false_negatives = (~pred_mask & gt_mask).sum().item()

    # Calculate precision, recall, and F1 score
    accuracy = (true_positives + true_negatives) / (gt.size()[1] * gt.size()[0])
    precision = true_positives / (true_positives + false_positives + 1e-10)
    recall = true_positives / (true_positives + false_negatives + 1e-10)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    return accuracy, precision, recall, f1


def calculate_metrics_l_with_mask(gt, pred, mask, conf_th=0.1, dist_th=5):

    # only works with batch 1. Should be adapted to batch > 1 in an organic way or just do a loop over batch

    geometry_mask = (mask[:, :-1] > 0).cpu()

    pred = pred[geometry_mask]
    gt = gt[geometry_mask]

    pred_pos = pred[:, :, :-1]
    gt_pos = gt[:, :, :-1]

    pred_mask = torch.all((pred[:, :, -1] > conf_th), dim=-1)
    gt_mask = torch.all((gt[:, :, -1] > conf_th), dim=-1)

    gt_flip = torch.flip(gt_pos, dims=[1])

    distances1 = torch.norm(pred_pos - gt_pos, dim=-1)
    distances2 = torch.norm(pred_pos - gt_flip, dim=-1)

    distances1_bool = torch.all((distances1 < dist_th), dim=-1)
    distances2_bool = torch.all((distances2 < dist_th), dim=-1)

    # Count true positives, false positives, and false negatives based on distance threshold
    true_positives = ((distances1_bool | distances2_bool) & pred_mask & gt_mask).sum().item()
    true_negatives = (~pred_mask & ~gt_mask).sum().item()
    false_positives = (
        (pred_mask & ~gt_mask) | ((~distances1_bool & ~distances2_bool) & pred_mask & gt_mask)).sum().item()
    false_negatives = (~pred_mask & gt_mask).sum().item()

    # Calculate precision, recall, and F1 score
    accuracy = (true_positives + true_negatives) / geometry_mask.sum().item()
    precision = true_positives / (true_positives + false_positives + 1e-10)
    recall = true_positives / (true_positives + false_negatives + 1e-10)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    return accuracy, precision, recall, f1


def calc_iou_whole_with_poly(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):

    corners = np.array([[0, 0],
                        [frame_w - 1, 0],
                        [frame_w - 1, frame_h - 1],
                        [0, frame_h - 1]], dtype=np.float64)

    mapping_mat = np.linalg.inv(gt_h)
    mapping_mat /= mapping_mat[2, 2]

    gt_corners = cv2.perspectiveTransform(
        corners[:, None, :], gt_h)  # inv_gt_mat * (gt_mat * [x, y, 1])
    gt_corners = cv2.perspectiveTransform(
        gt_corners, np.linalg.inv(gt_h))
    gt_corners = gt_corners[:, 0, :]

    pred_corners = cv2.perspectiveTransform(
        corners[:, None, :], gt_h)  # inv_pred_mat * (gt_mat * [x, y, 1])
    pred_corners = cv2.perspectiveTransform(
        pred_corners, np.linalg.inv(pred_h))
    pred_corners = pred_corners[:, 0, :]

    gt_poly = Polygon(gt_corners.tolist())
    pred_poly = Polygon(pred_corners.tolist())

    # f, axarr = plt.subplots(1, 2, figsize=(16, 12))
    # axarr[0].plot(*gt_poly.exterior.coords.xy)
    # axarr[1].plot(*pred_poly.exterior.coords.xy)
    # plt.show()

    if pred_poly.is_valid is False:
        return 0., None, None

    if not gt_poly.intersects(pred_poly):
        print('not intersects')
        iou = 0.
    else:
        intersection = gt_poly.intersection(pred_poly).area
        union = gt_poly.area + pred_poly.area - intersection
        if union <= 0.:
            print('whole union', union)
            iou = 0.
        else:
            iou = intersection / union

    return iou, None, None

def calc_iou_part(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):

    # field template binary mask
    field_mask = np.ones((frame_h, frame_w, 3), dtype=np.uint8) * 255
    gt_mask = cv2.warpPerspective(field_mask, gt_h, (template_w, template_h),
                                  cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT, borderValue=(0))

    pred_mask = cv2.warpPerspective(field_mask, pred_h, (template_w, template_h),
                                    cv2.INTER_AREA, borderMode=cv2.BORDER_CONSTANT, borderValue=(0))

    gt_mask[gt_mask > 0] = 255
    pred_mask[pred_mask > 0] = 255

    intersection = ((gt_mask > 0) * (pred_mask > 0)).sum()
    union = (gt_mask > 0).sum() + (pred_mask > 0).sum() - intersection

    if union <= 0:
        print('part union', union)
        # iou = float('nan')
        iou = 0.
    else:
        iou = float(intersection) / float(union)

    # === blending ===
    gt_white_area = (gt_mask[:, :, 0] == 255) & (
        gt_mask[:, :, 1] == 255) & (gt_mask[:, :, 2] == 255)
    gt_fill = gt_mask.copy()
    gt_fill[gt_white_area, 0] = 255
    gt_fill[gt_white_area, 1] = 0
    gt_fill[gt_white_area, 2] = 0
    pred_white_area = (pred_mask[:, :, 0] == 255) & (
        pred_mask[:, :, 1] == 255) & (pred_mask[:, :, 2] == 255)
    pred_fill = pred_mask.copy()
    pred_fill[pred_white_area, 0] = 0
    pred_fill[pred_white_area, 1] = 255
    pred_fill[pred_white_area, 2] = 0
    gt_maskf = gt_fill.astype(float) / 255
    pred_maskf = pred_fill.astype(float) / 255
    fill_resultf = cv2.addWeighted(gt_maskf, 0.5,
                                   pred_maskf, 0.5, 0.0)
    fill_result = np.uint8(fill_resultf * 255)

    return iou

def calc_proj_error(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):

    field_mask = np.ones((template_h, template_w, 3), dtype=np.uint8) * 255
    gt_mask = cv2.warpPerspective(field_mask, np.linalg.inv(
        gt_h), (frame_w, frame_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
    gt_gray = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
    contours, hierarchy = cv2.findContours(
        gt_gray, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    contour = np.squeeze(contours[0])
    poly = Polygon(contour)
    sample_pts = []
    num_pts = 2500
    while len(sample_pts) <= num_pts:
        x = random.sample(range(0, frame_w), 1)
        y = random.sample(range(0, frame_h), 1)
        p = Point(x[0], y[0])
        if p.within(poly):
            sample_pts.append([x[0], y[0]])
    sample_pts = np.array(sample_pts, dtype=np.float32)

    field_dim_x, field_dim_y = 100, 60
    x_scale = field_dim_x / template_w
    y_scale = field_dim_y / template_h
    scaling_mat = np.eye(3)
    scaling_mat[0, 0] = x_scale
    scaling_mat[1, 1] = y_scale
    gt_temp_grid = cv2.perspectiveTransform(
        sample_pts.reshape(-1, 1, 2), scaling_mat @ gt_h)
    gt_temp_grid = gt_temp_grid.reshape(-1, 2)
    pred_temp_grid = cv2.perspectiveTransform(
        sample_pts.reshape(-1, 1, 2), scaling_mat @ pred_h)
    pred_temp_grid = pred_temp_grid.reshape(-1, 2)

    # TODO compute distance in top view
    gt_grid_list = []
    pred_grid_list = []
    for gt_pts, pred_pts in zip(gt_temp_grid, pred_temp_grid):
        if 0 <= gt_pts[0] < field_dim_x and 0 <= gt_pts[1] < field_dim_y and \
                0 <= pred_pts[0] < field_dim_x and 0 <= pred_pts[1] < field_dim_y:
            gt_grid_list.append(gt_pts)
            pred_grid_list.append(pred_pts)
    gt_grid_list = np.array(gt_grid_list)
    pred_grid_list = np.array(pred_grid_list)

    if gt_grid_list.shape != pred_grid_list.shape:
        print('proj error:', gt_grid_list.shape, pred_grid_list.shape)
        assert gt_grid_list.shape == pred_grid_list.shape, 'shape mismatch'

    if gt_grid_list.size != 0 and pred_grid_list.size != 0:
        distance_list = calc_euclidean_distance(
            gt_grid_list, pred_grid_list, axis=1)
        return distance_list.mean()  # average all keypoints
    else:
        print(gt_grid_list)
        print(pred_grid_list)
        return float('nan')

def calc_euclidean_distance(a, b, _norm=np.linalg.norm, axis=None):
    return _norm(a - b, axis=axis)

def gen_template_grid():
    # === set uniform grid ===
    # field_dim_x, field_dim_y = 105.000552, 68.003928 # in meter
    field_dim_x, field_dim_y = 114.83, 74.37  # in yard
    # field_dim_x, field_dim_y = 115, 74 # in yard
|
264 |
+
nx, ny = (13, 7)
|
265 |
+
x = np.linspace(0, field_dim_x, nx)
|
266 |
+
y = np.linspace(0, field_dim_y, ny)
|
267 |
+
xv, yv = np.meshgrid(x, y, indexing='ij')
|
268 |
+
uniform_grid = np.stack((xv, yv), axis=2).reshape(-1, 2)
|
269 |
+
uniform_grid = np.concatenate((uniform_grid, np.ones(
|
270 |
+
(uniform_grid.shape[0], 1))), axis=1) # top2bottom, left2right
|
271 |
+
# TODO: class label in template, each keypoints is (x, y, c), c is label that starts from 1
|
272 |
+
for idx, pts in enumerate(uniform_grid):
|
273 |
+
pts[2] = idx + 1 # keypoints label
|
274 |
+
return uniform_grid
|
275 |
+
|
276 |
+
def calc_reproj_error(pred_h, gt_h, frame_w=1280, frame_h=720, template_w=115, template_h=74):
|
277 |
+
|
278 |
+
uniform_grid = gen_template_grid() # grid shape (91, 3), (x, y, label)
|
279 |
+
template_grid = uniform_grid[:, :2].copy()
|
280 |
+
template_grid = template_grid.reshape(-1, 1, 2)
|
281 |
+
|
282 |
+
gt_warp_grid = cv2.perspectiveTransform(template_grid, np.linalg.inv(gt_h))
|
283 |
+
gt_warp_grid = gt_warp_grid.reshape(-1, 2)
|
284 |
+
pred_warp_grid = cv2.perspectiveTransform(
|
285 |
+
template_grid, np.linalg.inv(pred_h))
|
286 |
+
pred_warp_grid = pred_warp_grid.reshape(-1, 2)
|
287 |
+
|
288 |
+
# TODO compute distance in camera view
|
289 |
+
gt_grid_list = []
|
290 |
+
pred_grid_list = []
|
291 |
+
for gt_pts, pred_pts in zip(gt_warp_grid, pred_warp_grid):
|
292 |
+
if 0 <= gt_pts[0] < frame_w and 0 <= gt_pts[1] < frame_h and \
|
293 |
+
0 <= pred_pts[0] < frame_w and 0 <= pred_pts[1] < frame_h:
|
294 |
+
gt_grid_list.append(gt_pts)
|
295 |
+
pred_grid_list.append(pred_pts)
|
296 |
+
gt_grid_list = np.array(gt_grid_list)
|
297 |
+
pred_grid_list = np.array(pred_grid_list)
|
298 |
+
|
299 |
+
if gt_grid_list.shape != pred_grid_list.shape:
|
300 |
+
print('reproj error:', gt_grid_list.shape, pred_grid_list.shape)
|
301 |
+
assert gt_grid_list.shape == pred_grid_list.shape, 'shape mismatch'
|
302 |
+
|
303 |
+
if gt_grid_list.size != 0 and pred_grid_list.size != 0:
|
304 |
+
distance_list = calc_euclidean_distance(
|
305 |
+
gt_grid_list, pred_grid_list, axis=1)
|
306 |
+
distance_list /= frame_h # normalize by image height
|
307 |
+
return distance_list.mean() # average all keypoints
|
308 |
+
else:
|
309 |
+
print(gt_grid_list)
|
310 |
+
print(pred_grid_list)
|
311 |
+
return float('nan')
|
312 |
+
|
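For reference, a minimal sketch of how these metrics can be exercised on a synthetic ground-truth/predicted homography pair. The `model.metrics` import path matches the repository layout used by scripts/eval_tswc.py; the homography values themselves are illustrative only, not taken from the datasets.

# Hedged usage sketch (synthetic homographies, illustrative values only).
import numpy as np
from model.metrics import calc_iou_part, calc_iou_whole_with_poly, calc_proj_error, calc_reproj_error

# A plausible frame-to-template homography (ground truth) and a slightly perturbed prediction.
gt_h = np.array([[0.1, 0.0, 10.0],
                 [0.0, 0.1,  5.0],
                 [0.0, 0.0,  1.0]])
perturb = np.array([[1.0, 0.0, 2.0],   # small translation error in frame space
                    [0.0, 1.0, 1.0],
                    [0.0, 0.0, 1.0]])
pred_h = gt_h @ perturb

iou_part = calc_iou_part(pred_h, gt_h)                    # IoU of warped field masks in template space
iou_whole, _, _ = calc_iou_whole_with_poly(pred_h, gt_h)  # polygon IoU of projected frame corners
proj_err = calc_proj_error(pred_h, gt_h)                  # mean sampled-point distance in top view
reproj_err = calc_reproj_error(pred_h, gt_h)              # mean grid distance in camera view, normalized by height
print(iou_part, iou_whole, proj_err, reproj_err)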
model/transforms.py
ADDED
@@ -0,0 +1,364 @@
import torch
import math
import numbers
import warnings
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as f
import torchvision.transforms as T
import torchvision.transforms.v2 as v2

from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
from torchvision import transforms as _transforms
from typing import List, Optional, Tuple, Union
from scipy import ndimage
from torch import Tensor

from sn_calibration.src.evaluate_extremities import mirror_labels


class ToTensor(torch.nn.Module):
    def __call__(self, sample):
        image = sample['image']

        return {'image': f.to_tensor(image).float(),
                'data': sample['data']}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, sample):
        image = sample['image']
        image = f.normalize(image, self.mean, self.std)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


FLIP_POSTS = {
    'Goal left post right': 'Goal left post left ',
    'Goal left post left ': 'Goal left post right',
    'Goal right post right': 'Goal right post left',
    'Goal right post left': 'Goal right post right'
}

h_lines = ['Goal left crossbar', 'Side line left', 'Small rect. left main', 'Big rect. left main', 'Middle line',
           'Big rect. right main', 'Small rect. right main', 'Side line right', 'Goal right crossbar']
v_lines = ['Side line top', 'Big rect. left top', 'Small rect. left top', 'Small rect. left bottom',
           'Big rect. left bottom', 'Big rect. right top', 'Small rect. right top', 'Small rect. right bottom',
           'Big rect. right bottom', 'Side line bottom']


def swap_top_bottom_names(line_name: str) -> str:
    x: str = 'top'
    y: str = 'bottom'
    if x in line_name or y in line_name:
        return y.join(part.replace(y, x) for part in line_name.split(x))
    return line_name


def swap_posts_names(line_name: str) -> str:
    if line_name in FLIP_POSTS:
        return FLIP_POSTS[line_name]
    return line_name


def flip_annot_names(annot, swap_top_bottom: bool = True,
                     swap_posts: bool = True):
    annot = mirror_labels(annot)
    if swap_top_bottom:
        annot = {swap_top_bottom_names(k): v for k, v in annot.items()}
    if swap_posts:
        annot = {swap_posts_names(k): v for k, v in annot.items()}
    return annot


class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, sample):
        if torch.rand(1) < self.p:
            image, data = sample['image'], sample['data']
            image = f.hflip(image)
            data = flip_annot_names(data)
            for line in data:
                for point in data[line]:
                    point['x'] = 1.0 - point['x']

            return {'image': image,
                    'data': data}
        else:
            return {'image': sample['image'],
                    'data': sample['data']}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class LRAmbiguityFix(torch.nn.Module):
    def __init__(self, v_th, h_th):
        super().__init__()
        self.v_th = v_th
        self.h_th = h_th

    def forward(self, sample):
        data = sample['data']

        if len(data) == 0:
            return {'image': sample['image'],
                    'data': sample['data']}

        n_left, n_right = self.compute_n_sides(data)

        angles_v, angles_h = [], []
        for line in data.keys():
            line_points = []
            for point in data[line]:
                line_points.append((point['x'], point['y']))

            sorted_points = sorted(line_points, key=lambda point: (point[0], point[1]))
            pi, pf = sorted_points[0], sorted_points[-1]
            if line in h_lines:
                angle_h = self.calculate_angle_h(pi[0], pi[1], pf[0], pf[1])
                if angle_h:
                    angles_h.append(abs(angle_h))
            if line in v_lines:
                angle_v = self.calculate_angle_v(pi[0], pi[1], pf[0], pf[1])
                if angle_v:
                    angles_v.append(abs(angle_v))

        if len(angles_h) > 0 and len(angles_v) > 0:
            if np.mean(angles_h) < self.h_th and np.mean(angles_v) < self.v_th:
                if n_right > n_left:
                    data = flip_annot_names(data, swap_top_bottom=False, swap_posts=False)

        return {'image': sample['image'],
                'data': data}

    def calculate_angle_h(self, x1, y1, x2, y2):
        if not x2 - x1 == 0:
            slope = (y2 - y1) / (x2 - x1)
            angle = math.atan(slope)
            angle_degrees = math.degrees(angle)
            return angle_degrees
        else:
            return None

    def calculate_angle_v(self, x1, y1, x2, y2):
        if not x2 - x1 == 0:
            slope = (y2 - y1) / (x2 - x1)
            angle = math.atan(1 / slope) if slope != 0 else math.pi / 2  # Avoid division by zero
            angle_degrees = math.degrees(angle)
            return angle_degrees
        else:
            return None

    def compute_n_sides(self, data):
        n_left, n_right = 0, 0
        for line in data:
            line_words = line.split()[:3]
            if 'left' in line_words:
                n_left += 1
            elif 'right' in line_words:
                n_right += 1
        return n_left, n_right

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(v_th={self.v_th}, h_th={self.h_th})"


class AddGaussianNoise(torch.nn.Module):
    def __init__(self, mean=0., std=2.):
        self.std = std
        self.mean = mean

    def __call__(self, sample):
        image = sample['image']
        image += torch.randn(image.size()) * self.std + self.mean
        image = torch.clip(image, 0, 1)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ColorJitter(torch.nn.Module):

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, sample):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        image = sample['image']

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                image = f.adjust_brightness(image, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                image = f.adjust_contrast(image, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                image = f.adjust_saturation(image, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                image = f.adjust_hue(image, hue_factor)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


class Resize(torch.nn.Module):
    def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        self.size = size

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation

    def forward(self, sample):
        image = sample["image"]
        image = f.resize(image, self.size, self.interpolation)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self):
        interpolate_str = self.interpolation.value
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


transforms = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=.5),
    ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
    AddGaussianNoise(0, .1)
])

transforms_w_LR = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=.5),
    LRAmbiguityFix(v_th=70, h_th=20),
    ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
    AddGaussianNoise(0, .1)
])

no_transforms = v2.Compose([
    ToTensor()
])

no_transforms_w_LR = v2.Compose([
    ToTensor(),
    LRAmbiguityFix(v_th=70, h_th=20)
])
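A small sketch of how these pipelines are driven: each transform consumes and returns a {'image', 'data'} sample, where 'data' maps SoccerNet line names to normalized point annotations. The dummy image and annotation below are illustrative only; the `model.transforms` import path follows this repository layout.

# Illustrative only: push a {'image', 'data'} sample through the augmentation pipelines above.
import numpy as np
from model.transforms import transforms, no_transforms

dummy_image = (np.random.rand(540, 960, 3) * 255).astype(np.uint8)   # HWC uint8, as f.to_tensor expects
dummy_data = {'Side line top': [{'x': 0.10, 'y': 0.20}, {'x': 0.90, 'y': 0.25}]}  # normalized extremities

augmented = transforms({'image': dummy_image, 'data': dummy_data})   # random flip, jitter, noise
plain = no_transforms({'image': dummy_image, 'data': dummy_data})    # tensor conversion only
print(augmented['image'].shape, plain['image'].shape)                # both (3, 540, 960) float tensors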
model/transformsWC.py
ADDED
@@ -0,0 +1,212 @@
import torch
import math
import numbers
import torch.nn.functional as F
import torchvision.transforms.functional as f
import torchvision.transforms as T
import torchvision.transforms.v2 as v2

from torchvision import transforms as _transforms
from typing import List, Optional, Tuple, Union
from scipy import ndimage
from torch import Tensor

from sn_calibration.src.evaluate_extremities import mirror_labels


class ToTensor(torch.nn.Module):
    def __call__(self, sample):
        image, target, mask = sample['image'], sample['target'], sample['mask']

        return {'image': f.to_tensor(image).float(),
                'target': torch.from_numpy(target).float(),
                'mask': torch.from_numpy(mask).float()}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        self.swap_dict = {1:3, 2:2, 3:1, 4:7, 5:6, 6:5, 7:4, 8:11, 9:10, 10:9, 11:8, 12:15, 13:14, 14:13, 15:12,
                          16:19, 17:18, 18:17, 19:16, 20:23, 21:22, 22:21, 23:20, 24:27, 25:26, 26:25, 27:24, 28:30,
                          29:29, 30:28, 31:33, 32:32, 33:31, 34:36, 35:35, 36:34, 37:40, 38:39, 39:38, 40:37, 41:44,
                          42:43, 43:42, 44:41, 45:57, 46:56, 47:55, 48:49, 49:48, 50:52, 51:51, 52:50, 53:54, 54:53,
                          55:47, 56:46, 57:45, 58:58}

    def forward(self, sample):
        if torch.rand(1) < self.p:
            image, target, mask = sample['image'], sample['target'], sample['mask']
            image = f.hflip(image)
            target = f.hflip(target)

            target_swap, mask_swap = self.swap_layers(target, mask)

            return {'image': image,
                    'target': target_swap,
                    'mask': mask_swap}
        else:
            return {'image': sample['image'],
                    'target': sample['target'],
                    'mask': sample['mask']}

    def swap_layers(self, target, mask):
        target_swap = torch.zeros_like(target)
        mask_swap = torch.zeros_like(mask)
        for kp in self.swap_dict.keys():
            kp_swap = self.swap_dict[kp]
            target_swap[kp_swap-1, :, :] = target[kp-1, :, :].clone()
            mask_swap[kp_swap-1] = mask[kp-1].clone()

        return target_swap, mask_swap

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class AddGaussianNoise(torch.nn.Module):
    def __init__(self, mean=0., std=2.):
        self.std = std
        self.mean = mean

    def __call__(self, sample):
        image = sample['image']
        image += torch.randn(image.size()) * self.std + self.mean
        image = torch.clip(image, 0, 1)

        return {'image': image,
                'target': sample['target'],
                'mask': sample['mask']}

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ColorJitter(torch.nn.Module):

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, sample):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        image = sample['image']

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                image = f.adjust_brightness(image, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                image = f.adjust_contrast(image, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                image = f.adjust_saturation(image, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                image = f.adjust_hue(image, hue_factor)

        return {'image': image,
                'target': sample['target'],
                'mask': sample['mask']}

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


transforms = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=.5),
    ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
    AddGaussianNoise(0, .1)
])


no_transforms = v2.Compose([
    ToTensor(),
])
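Sketch of the channel re-indexing done on flip: when the image is mirrored, each keypoint heatmap channel is moved to its left/right counterpart via swap_dict (keypoint 1 trades places with keypoint 3, and so on). The 58-channel shapes below are illustrative and assume tensors have already been converted.

# Illustrative only: the flip swaps heatmap channels according to swap_dict.
import torch
from model.transformsWC import RandomHorizontalFlip

flip = RandomHorizontalFlip(p=1.0)      # torch.rand(1) < 1.0 always holds, so the flip always fires
target = torch.zeros(58, 135, 240)      # per-keypoint heatmaps
mask = torch.zeros(58)
target[0, 10, 20] = 1.0                 # activate keypoint 1
mask[0] = 1.0

out = flip({'image': torch.zeros(3, 540, 960), 'target': target, 'mask': mask})
print(out['mask'].nonzero().flatten())  # tensor([2]) -> keypoint 1 ended up in keypoint 3's channel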
model/transformsWC_l.py
ADDED
@@ -0,0 +1,208 @@
import torch
import math
import numbers
import torch.nn.functional as F
import torchvision.transforms.functional as f
import torchvision.transforms as T
import torchvision.transforms.v2 as v2

from torchvision import transforms as _transforms
from typing import List, Optional, Tuple, Union
from scipy import ndimage
from torch import Tensor

from sn_calibration.src.evaluate_extremities import mirror_labels


class ToTensor(torch.nn.Module):
    def __call__(self, sample):
        image, target, mask = sample['image'], sample['target'], sample['mask']

        return {'image': f.to_tensor(image).float(),
                'target': torch.from_numpy(target).float(),
                'mask': torch.from_numpy(mask).float()}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        self.swap_dict = {1: 4, 2: 5, 3: 6, 4: 1, 5: 2, 6: 3, 7: 10, 8: 12, 9: 11, 10: 7, 11: 9, 12: 8, 13: 13,
                          14: 14, 15: 16, 16: 15, 17: 17, 18: 21, 19: 22, 20: 23, 21: 18, 22: 19, 23: 20, 24: 24}

    def forward(self, sample):
        if torch.rand(1) < self.p:
            image, target, mask = sample['image'], sample['target'], sample['mask']
            image = f.hflip(image)
            target = f.hflip(target)

            target_swap, mask_swap = self.swap_layers(target, mask)

            return {'image': image,
                    'target': target_swap,
                    'mask': mask_swap}
        else:
            return {'image': sample['image'],
                    'target': sample['target'],
                    'mask': sample['mask']}

    def swap_layers(self, target, mask):
        target_swap = torch.zeros_like(target)
        mask_swap = torch.zeros_like(mask)
        for kp in self.swap_dict.keys():
            kp_swap = self.swap_dict[kp]
            target_swap[kp_swap-1, :, :] = target[kp-1, :, :].clone()
            mask_swap[kp_swap-1] = mask[kp-1].clone()

        return target_swap, mask_swap

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class AddGaussianNoise(torch.nn.Module):
    def __init__(self, mean=0., std=2.):
        self.std = std
        self.mean = mean

    def __call__(self, sample):
        image = sample['image']
        image += torch.randn(image.size()) * self.std + self.mean
        image = torch.clip(image, 0, 1)

        return {'image': image,
                'target': sample['target'],
                'mask': sample['mask']}

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ColorJitter(torch.nn.Module):

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, sample):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        image = sample['image']

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                image = f.adjust_brightness(image, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                image = f.adjust_contrast(image, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                image = f.adjust_saturation(image, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                image = f.adjust_hue(image, hue_factor)

        return {'image': image,
                'target': sample['target'],
                'mask': sample['mask']}

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


transforms = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=.5),
    ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
    AddGaussianNoise(0, .1)
])


no_transforms = v2.Compose([
    ToTensor(),
])
model/transforms_l.py
ADDED
@@ -0,0 +1,360 @@
import math
import torch
import numbers
import warnings  # needed by Resize below (warnings.warn); this import was missing in the original file
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as f
import torchvision.transforms.v2 as v2

from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode
from torchvision import transforms as _transforms
from typing import List, Optional, Tuple, Union
from scipy import ndimage
from torch import Tensor

from sn_calibration.src.evaluate_extremities import mirror_labels


class ToTensor(torch.nn.Module):
    def __call__(self, sample):
        image, target = sample['image'], sample['data']

        return {'image': f.to_tensor(image).float(),
                'data': sample['data']}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, sample):
        image = sample['image']
        image = f.normalize(image, self.mean, self.std)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


FLIP_POSTS = {
    'Goal left post right': 'Goal left post left ',
    'Goal left post left ': 'Goal left post right',
    'Goal right post right': 'Goal right post left',
    'Goal right post left': 'Goal right post right'
}

h_lines = ['Goal left crossbar', 'Side line left', 'Small rect. left main', 'Big rect. left main', 'Middle line',
           'Big rect. right main', 'Small rect. right main', 'Side line right', 'Goal right crossbar']
v_lines = ['Side line top', 'Big rect. left top', 'Small rect. left top', 'Small rect. left bottom',
           'Big rect. left bottom', 'Big rect. right top', 'Small rect. right top', 'Small rect. right bottom',
           'Big rect. right bottom', 'Side line bottom']


def swap_top_bottom_names(line_name: str) -> str:
    x: str = 'top'
    y: str = 'bottom'
    if x in line_name or y in line_name:
        return y.join(part.replace(y, x) for part in line_name.split(x))
    return line_name


def swap_posts_names(line_name: str) -> str:
    if line_name in FLIP_POSTS:
        return FLIP_POSTS[line_name]
    return line_name


def flip_annot_names(annot, swap_top_bottom: bool = True, swap_posts: bool = True):
    annot = mirror_labels(annot)
    if swap_top_bottom:
        annot = {swap_top_bottom_names(k): v for k, v in annot.items()}
    if swap_posts:
        annot = {swap_posts_names(k): v for k, v in annot.items()}
    return annot


class LRAmbiguityFix(torch.nn.Module):
    def __init__(self, v_th, h_th):
        super().__init__()
        self.v_th = v_th
        self.h_th = h_th

    def forward(self, sample):
        data = sample['data']

        if len(data) == 0:
            return {'image': sample['image'],
                    'data': sample['data']}

        n_left, n_right = self.compute_n_sides(data)

        angles_v, angles_h = [], []
        for line in data.keys():
            line_points = []
            for point in data[line]:
                line_points.append((point['x'], point['y']))

            sorted_points = sorted(line_points, key=lambda point: (point[0], point[1]))
            pi, pf = sorted_points[0], sorted_points[-1]
            if line in h_lines:
                angle_h = self.calculate_angle_h(pi[0], pi[1], pf[0], pf[1])
                if angle_h:
                    angles_h.append(abs(angle_h))
            if line in v_lines:
                angle_v = self.calculate_angle_v(pi[0], pi[1], pf[0], pf[1])
                if angle_v:
                    angles_v.append(abs(angle_v))

        if len(angles_h) > 0 and len(angles_v) > 0:
            if np.mean(angles_h) < self.h_th and np.mean(angles_v) < self.v_th:
                if n_right > n_left:
                    data = flip_annot_names(data, swap_top_bottom=False, swap_posts=False)

        return {'image': sample['image'],
                'data': data}

    def calculate_angle_h(self, x1, y1, x2, y2):
        if not x2 - x1 == 0:
            slope = (y2 - y1) / (x2 - x1)
            angle = math.atan(slope)
            angle_degrees = math.degrees(angle)
            return angle_degrees
        else:
            return None

    def calculate_angle_v(self, x1, y1, x2, y2):
        if not x2 - x1 == 0:
            slope = (y2 - y1) / (x2 - x1)
            angle = math.atan(1 / slope) if slope != 0 else math.pi / 2  # Avoid division by zero
            angle_degrees = math.degrees(angle)
            return angle_degrees
        else:
            return None

    def compute_n_sides(self, data):
        n_left, n_right = 0, 0
        for line in data:
            line_words = line.split()[:3]
            if 'left' in line_words:
                n_left += 1
            elif 'right' in line_words:
                n_right += 1
        return n_left, n_right

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(v_th={self.v_th}, h_th={self.h_th})"


class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, sample):
        if torch.rand(1) < self.p:
            image, data = sample['image'], sample['data']
            image = f.hflip(image)
            data = flip_annot_names(data)
            for line in data:
                for point in data[line]:
                    point['x'] = 1.0 - point['x']

            return {'image': image,
                    'data': data}
        else:
            return {'image': sample['image'],
                    'data': sample['data']}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class AddGaussianNoise(torch.nn.Module):
    def __init__(self, mean=0., std=2.):
        self.std = std
        self.mean = mean

    def __call__(self, sample):
        image = sample['image']
        image += torch.randn(image.size()) * self.std + self.mean
        image = torch.clip(image, 0, 1)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ColorJitter(torch.nn.Module):

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, sample):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        image = sample['image']

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                image = f.adjust_brightness(image, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                image = f.adjust_contrast(image, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                image = f.adjust_saturation(image, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                image = f.adjust_hue(image, hue_factor)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


class Resize(torch.nn.Module):
    def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        self.size = size

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation

    def forward(self, sample):
        image = sample["image"]
        image = f.resize(image, self.size, self.interpolation)

        return {'image': image,
                'data': sample['data']}

    def __repr__(self):
        interpolate_str = self.interpolation.value
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


transforms = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=.5),
    ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
    AddGaussianNoise(0, .1)
])

transforms_w_LR = v2.Compose([
    ToTensor(),
    RandomHorizontalFlip(p=.5),
    LRAmbiguityFix(v_th=70, h_th=20),
    ColorJitter(brightness=(0.05), contrast=(0.05), saturation=(0.05), hue=(0.05)),
    AddGaussianNoise(0, .1)
])

no_transforms = v2.Compose([
    ToTensor(),
])

no_transforms_w_LR = v2.Compose([
    ToTensor(),
    LRAmbiguityFix(v_th=70, h_th=20)
])
requirements.txt
ADDED
@@ -0,0 +1,44 @@
# # API Framework
# fastapi==0.104.1
# uvicorn[standard]==0.24.0
# python-multipart==0.0.6
# pydantic==2.5.0

# # Core dependencies (compatible with Python 3.10)
# numpy==1.24.3
# opencv-python-headless==4.8.1.78
# pillow==10.1.0
# scipy==1.11.4
# PyYAML==6.0.1
# lsq-ellipse==2.2.1
# shapely==2.0.2

# # PyTorch CPU (compatible with Python 3.10)
# torch==2.1.0
# torchvision==0.16.0

# # Utilities
# tqdm==4.66.1

# requirements_clean.txt
fastapi
uvicorn[standard]
python-multipart
pydantic

# Core dependencies
numpy
opencv-python-headless
pillow
scipy
PyYAML
lsq-ellipse
shapely

# PyTorch CPU only (much lighter); the extra index below resolves the CPU wheels,
# so plain "torch"/"torchvision" specifiers suffice ("torch+cpu" is not valid pip syntax)
--extra-index-url https://download.pytorch.org/whl/cpu
torch
torchvision

# Utilities
tqdm
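A quick sanity check for the CPU-only install is sketched below; with the CPU wheel index, the installed version string normally carries a "+cpu" local tag and CUDA is reported as unavailable.

# Illustrative check after installing from this requirements file.
import torch
print(torch.__version__)          # a "+cpu" suffix is expected with the CPU wheel index
print(torch.cuda.is_available())  # False on the CPU-only build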
run_api.py
ADDED
@@ -0,0 +1,13 @@
import uvicorn

if __name__ == "__main__":
    print("🚀 Starting the Football Vision Calibration API...")
    print("📍 API available at: http://localhost:8000")
    print("📖 Documentation: http://localhost:8000/docs")

    uvicorn.run(
        "api:app",  # pass the app as an import string (not the object) so reload can work
        host="0.0.0.0",
        port=8000,
        reload=True  # automatic reload during development
    )
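Once the server is running, FastAPI's auto-generated OpenAPI schema can be used from the client side to discover the routes exposed by api.py; the sketch below assumes the default /openapi.json endpoint has not been disabled and that the `requests` package is available.

# Illustrative client-side check against the running server.
import requests

resp = requests.get("http://localhost:8000/openapi.json", timeout=5)
resp.raise_for_status()
print(sorted(resp.json().get("paths", {}).keys()))  # lists the routes defined in api.py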
scripts/eval_tswc.py
ADDED
@@ -0,0 +1,175 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
import zipfile
|
7 |
+
import argparse
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
13 |
+
from utils.utils_calib import FramebyFrameCalib
|
14 |
+
from model.metrics import calc_iou_part, calc_iou_whole_with_poly, calc_reproj_error, calc_proj_error
|
15 |
+
|
16 |
+
def parse_args():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--root_dir", type=str, required=True)
|
19 |
+
parser.add_argument("--split", type=str, required=True)
|
20 |
+
parser.add_argument("--pred_file", type=str, required=True)
|
21 |
+
|
22 |
+
args = parser.parse_args()
|
23 |
+
return args
|
24 |
+
|
25 |
+
|
26 |
+
def get_homographies(file_paths):
|
27 |
+
npy_files = []
|
28 |
+
for file_path in file_paths:
|
29 |
+
directory_path = os.path.join(os.path.join(args.root_dir, "Annotations/80_95"), file_path)
|
30 |
+
if os.path.exists(directory_path):
|
31 |
+
files = os.listdir(directory_path)
|
32 |
+
npy_files.extend([os.path.join(directory_path, file) for file in files if file.endswith('.npy')])
|
33 |
+
|
34 |
+
npy_files = sorted(npy_files)
|
35 |
+
return npy_files
|
36 |
+
|
37 |
+
|
38 |
+
def make_file_name(file):
|
39 |
+
file = "TS-WorldCup/" + file.split("TS-WorldCup/")[-1]
|
40 |
+
splits = file.split('/')
|
41 |
+
side = splits[3]
|
42 |
+
match = splits[4]
|
43 |
+
image = splits[5]
|
44 |
+
frame = image.split('_homography.npy')[0]
|
45 |
+
return side + '-' + match + '-' + frame + '.json'
|
46 |
+
|
47 |
+
|
48 |
+
def pan_tilt_roll_to_orientation(pan, tilt, roll):
|
49 |
+
"""
|
50 |
+
Conversion from euler angles to orientation matrix.
|
51 |
+
:param pan:
|
52 |
+
:param tilt:
|
53 |
+
:param roll:
|
54 |
+
:return: orientation matrix
|
55 |
+
"""
|
56 |
+
Rpan = np.array([
|
57 |
+
[np.cos(pan), -np.sin(pan), 0],
|
58 |
+
[np.sin(pan), np.cos(pan), 0],
|
59 |
+
[0, 0, 1]])
|
60 |
+
Rroll = np.array([
|
61 |
+
[np.cos(roll), -np.sin(roll), 0],
|
62 |
+
[np.sin(roll), np.cos(roll), 0],
|
63 |
+
[0, 0, 1]])
|
64 |
+
Rtilt = np.array([
|
65 |
+
[1, 0, 0],
|
66 |
+
[0, np.cos(tilt), -np.sin(tilt)],
|
67 |
+
[0, np.sin(tilt), np.cos(tilt)]])
|
68 |
+
rotMat = np.dot(Rpan, np.dot(Rtilt, Rroll))
|
69 |
+
return rotMat
|
70 |
+
|
71 |
+
def get_sn_homography(cam_params: dict, batch_size=1):
|
72 |
+
# Extract relevant camera parameters from the dictionary
|
73 |
+
pan_degrees = cam_params['cam_params']['pan_degrees']
|
74 |
+
tilt_degrees = cam_params['cam_params']['tilt_degrees']
|
75 |
+
roll_degrees = cam_params['cam_params']['roll_degrees']
|
76 |
+
x_focal_length = cam_params['cam_params']['x_focal_length']
|
77 |
+
y_focal_length = cam_params['cam_params']['y_focal_length']
|
78 |
+
principal_point = np.array(cam_params['cam_params']['principal_point'])
|
79 |
+
position_meters = np.array(cam_params['cam_params']['position_meters'])
|
80 |
+
|
81 |
+
pan = pan_degrees * np.pi / 180.
|
82 |
+
tilt = tilt_degrees * np.pi / 180.
|
83 |
+
roll = roll_degrees * np.pi / 180.
|
84 |
+
|
85 |
+
rotation = np.array([[-np.sin(pan) * np.sin(roll) * np.cos(tilt) + np.cos(pan) * np.cos(roll),
|
86 |
+
np.sin(pan) * np.cos(roll) + np.sin(roll) * np.cos(pan) * np.cos(tilt), np.sin(roll) * np.sin(tilt)],
|
87 |
+
[-np.sin(pan) * np.cos(roll) * np.cos(tilt) - np.sin(roll) * np.cos(pan),
|
88 |
+
-np.sin(pan) * np.sin(roll) + np.cos(pan) * np.cos(roll) * np.cos(tilt), np.sin(tilt) * np.cos(roll)],
|
89 |
+
[np.sin(pan) * np.sin(tilt), -np.sin(tilt) * np.cos(pan), np.cos(tilt)]], dtype='float')
|
90 |
+
|
91 |
+
rotation = np.transpose(pan_tilt_roll_to_orientation(pan, tilt, roll))
|
92 |
+
|
93 |
+
def convert_homography_SN_to_WC14(H):
|
94 |
+
T = np.eye(3)
|
95 |
+
T[0, -1] = 105 / 2
|
96 |
+
T[1, -1] = 68 / 2
|
97 |
+
meter2yard = 1.09361
|
98 |
+
S = np.eye(3)
|
99 |
+
S[0, 0] = meter2yard
|
100 |
+
S[1, 1] = meter2yard
|
101 |
+
H_SN = S @ (T @ H)
|
102 |
+
return H_SN
|
103 |
+
|
104 |
+
def get_homography_by_index(homography_file):
|
105 |
+
homography = np.load(homography_file)
|
106 |
+
homography = homography / homography[2:3, 2:3]
|
107 |
+
return homography
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
args = parse_args()
|
112 |
+
|
113 |
+
missed = 0
|
114 |
+
iou_part_list, iou_whole_list = [], []
|
115 |
+
rep_err_list, proj_err_list = [], []
|
116 |
+
|
117 |
+
with open(args.root_dir + args.split + '.txt', 'r') as file:
|
118 |
+
# Read lines from the file and remove trailing newline characters
|
119 |
+
seqs = [line.strip() for line in file.readlines()]
|
120 |
+
|
121 |
+
homographies = get_homographies(seqs)
|
122 |
+
prediction_archive = zipfile.ZipFile(args.pred_file, 'r')
|
123 |
+
cam = FramebyFrameCalib(1280, 720, denormalize=True)
|
124 |
+
|
125 |
+
for h_gt in tqdm(homographies):
|
126 |
+
file_name = h_gt.split('/')[-1].split('.')[0]
|
127 |
+
pred_name = make_file_name(h_gt)
|
128 |
+
|
129 |
+
if pred_name not in prediction_archive.namelist():
|
130 |
+
missed += 1
|
131 |
+
continue
|
132 |
+
|
133 |
+
homography_gt = get_homography_by_index(h_gt)
|
134 |
+
final_dict = prediction_archive.read(pred_name)
|
135 |
+
final_dict = json.loads(final_dict.decode('utf-8'))
|
136 |
+
keypoints_dict = final_dict['kp']
|
137 |
+
lines_dict = final_dict['lines']
|
138 |
+
keypoints_dict = {int(key): value for key, value in keypoints_dict.items()}
|
139 |
+
lines_dict = {int(key): value for key, value in lines_dict.items()}
|
140 |
+
|
141 |
+
cam.update(keypoints_dict, lines_dict)
|
142 |
+
final_dict = cam.heuristic_voting_ground(refine_lines=True)
|
143 |
+
#homography_pred, ret = cam.get_homography_from_ground_plane(use_ransac=20, inverse=True, refine_lines=True)
|
144 |
+
if final_dict is None:
|
145 |
+
#if homography_pred is None:
|
146 |
+
missed += 1
|
147 |
+
continue
|
148 |
+
homography_pred = final_dict["homography"]
|
149 |
+
homography_pred = convert_homography_SN_to_WC14(homography_pred)
|
150 |
+
|
151 |
+
iou_p = calc_iou_part(homography_pred, homography_gt)
|
152 |
+
iou_w, _, _ = calc_iou_whole_with_poly(homography_pred, homography_gt)
|
153 |
+
rep_err = calc_reproj_error(homography_pred, homography_gt)
|
154 |
+
proj_err = calc_proj_error(homography_pred, homography_gt)
|
155 |
+
|
156 |
+
iou_part_list.append(iou_p)
|
157 |
+
iou_whole_list.append(iou_w)
|
158 |
+
rep_err_list.append(rep_err)
|
159 |
+
proj_err_list.append(proj_err)
|
160 |
+
|
161 |
+
|
162 |
+
print(f'Completeness: {1-missed/len(homographies)}')
|
163 |
+
print('IOU Part')
|
164 |
+
print(f'mean: {np.mean(iou_part_list)} \t median: {np.median(iou_part_list)}')
|
165 |
+
print('\nIOU Whole')
|
166 |
+
print(f'mean: {np.mean(iou_whole_list)} \t median: {np.median(iou_whole_list)}')
|
167 |
+
print('\nReprojection Err.')
|
168 |
+
print(f'mean: {np.mean(rep_err_list)} \t median: {np.median(rep_err_list)}')
|
169 |
+
print('\nProjection Err. (meters)')
|
170 |
+
print(f'mean: {np.mean(proj_err_list) * 0.9144} \t median: {np.median(proj_err_list) * 0.9144}')
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
scripts/eval_wc14.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import zipfile
|
8 |
+
import argparse
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
from typing import final
|
13 |
+
|
14 |
+
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
15 |
+
from utils.utils_calib import FramebyFrameCalib
|
16 |
+
from model.metrics import calc_iou_part, calc_iou_whole_with_poly, calc_reproj_error, calc_proj_error
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument("--root_dir", type=str, required=True)
|
21 |
+
parser.add_argument("--split", type=str, required=True)
|
22 |
+
parser.add_argument("--pred_file", type=str, required=True)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
return args
|
26 |
+
|
27 |
+
|
28 |
+
def pan_tilt_roll_to_orientation(pan, tilt, roll):
|
29 |
+
"""
|
30 |
+
Conversion from euler angles to orientation matrix.
|
31 |
+
:param pan:
|
32 |
+
:param tilt:
|
33 |
+
:param roll:
|
34 |
+
:return: orientation matrix
|
35 |
+
"""
|
36 |
+
Rpan = np.array([
|
37 |
+
[np.cos(pan), -np.sin(pan), 0],
|
38 |
+
[np.sin(pan), np.cos(pan), 0],
|
39 |
+
[0, 0, 1]])
|
40 |
+
Rroll = np.array([
|
41 |
+
[np.cos(roll), -np.sin(roll), 0],
|
42 |
+
[np.sin(roll), np.cos(roll), 0],
|
43 |
+
[0, 0, 1]])
|
44 |
+
Rtilt = np.array([
|
45 |
+
[1, 0, 0],
|
46 |
+
[0, np.cos(tilt), -np.sin(tilt)],
|
47 |
+
[0, np.sin(tilt), np.cos(tilt)]])
|
48 |
+
rotMat = np.dot(Rpan, np.dot(Rtilt, Rroll))
|
49 |
+
return rotMat
|
50 |
+
|
51 |
+
def get_sn_homography(cam_params: dict, batch_size=1):
|
52 |
+
# Extract relevant camera parameters from the dictionary
|
53 |
+
pan_degrees = cam_params['cam_params']['pan_degrees']
|
54 |
+
tilt_degrees = cam_params['cam_params']['tilt_degrees']
|
55 |
+
roll_degrees = cam_params['cam_params']['roll_degrees']
|
56 |
+
x_focal_length = cam_params['cam_params']['x_focal_length']
|
57 |
+
y_focal_length = cam_params['cam_params']['y_focal_length']
|
58 |
+
principal_point = np.array(cam_params['cam_params']['principal_point'])
|
59 |
+
position_meters = np.array(cam_params['cam_params']['position_meters'])
|
60 |
+
|
61 |
+
pan = pan_degrees * np.pi / 180.
|
62 |
+
tilt = tilt_degrees * np.pi / 180.
|
63 |
+
roll = roll_degrees * np.pi / 180.
|
64 |
+
|
65 |
+
rotation = np.array([[-np.sin(pan) * np.sin(roll) * np.cos(tilt) + np.cos(pan) * np.cos(roll),
|
66 |
+
np.sin(pan) * np.cos(roll) + np.sin(roll) * np.cos(pan) * np.cos(tilt), np.sin(roll) * np.sin(tilt)],
|
67 |
+
[-np.sin(pan) * np.cos(roll) * np.cos(tilt) - np.sin(roll) * np.cos(pan),
|
68 |
+
-np.sin(pan) * np.sin(roll) + np.cos(pan) * np.cos(roll) * np.cos(tilt), np.sin(tilt) * np.cos(roll)],
|
69 |
+
[np.sin(pan) * np.sin(tilt), -np.sin(tilt) * np.cos(pan), np.cos(tilt)]], dtype='float')
|
70 |
+
|
71 |
+
rotation = np.transpose(pan_tilt_roll_to_orientation(pan, tilt, roll))
|
72 |
+
|
73 |
+
def convert_homography_SN_to_WC14(H):
|
74 |
+
T = np.eye(3)
|
75 |
+
T[0, -1] = 105 / 2
|
76 |
+
T[1, -1] = 68 / 2
|
77 |
+
meter2yard = 1.09361
|
78 |
+
S = np.eye(3)
|
79 |
+
S[0, 0] = meter2yard
|
80 |
+
S[1, 1] = meter2yard
|
81 |
+
H_SN = S @ (T @ H)
|
82 |
+
return H_SN
|
83 |
+
|
84 |
+
def get_homography_by_index(homography_file):
|
85 |
+
with open(homography_file, 'r') as file:
|
86 |
+
lines = file.readlines()
|
87 |
+
matrix_elements = []
|
88 |
+
for line in lines:
|
89 |
+
matrix_elements.extend([float(element) for element in line.split()])
|
90 |
+
homography = np.array(matrix_elements).reshape((3, 3))
|
91 |
+
homography = homography / homography[2:3, 2:3]
|
92 |
+
return homography
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
args = parse_args()
|
96 |
+
|
97 |
+
missed = 0
|
98 |
+
iou_part_list, iou_whole_list = [], []
|
99 |
+
rep_err_list, proj_err_list = [], []
|
100 |
+
|
101 |
+
homographies = glob.glob(os.path.join(args.root_dir + args.split, "*.homographyMatrix"))
|
102 |
+
prediction_archive = zipfile.ZipFile(args.pred_file, 'r')
|
103 |
+
cam = FramebyFrameCalib(1280, 720, denormalize=True)
|
104 |
+
|
105 |
+
for h_gt in tqdm(homographies):
|
106 |
+
file_name = h_gt.split('/')[-1].split('.')[0]
|
107 |
+
pred_name = file_name + '.json'
|
108 |
+
|
109 |
+
if pred_name not in prediction_archive.namelist():
|
110 |
+
missed += 1
|
111 |
+
continue
|
112 |
+
|
113 |
+
homography_gt = get_homography_by_index(h_gt)
|
114 |
+
final_dict = prediction_archive.read(pred_name)
|
115 |
+
final_dict = json.loads(final_dict.decode('utf-8'))
|
116 |
+
keypoints_dict = final_dict['kp']
|
117 |
+
lines_dict = final_dict['lines']
|
118 |
+
keypoints_dict = {int(key): value for key, value in keypoints_dict.items()}
|
119 |
+
lines_dict = {int(key): value for key, value in lines_dict.items()}
|
120 |
+
|
121 |
+
cam.update(keypoints_dict, lines_dict)
|
122 |
+
final_dict = cam.heuristic_voting_ground(refine_lines=True)
|
123 |
+
|
124 |
+
if final_dict is None:
|
125 |
+
missed += 1
|
126 |
+
continue
|
127 |
+
|
128 |
+
homography_pred = final_dict["homography"]
|
129 |
+
homography_pred = convert_homography_SN_to_WC14(homography_pred)
|
130 |
+
|
131 |
+
iou_p = calc_iou_part(homography_pred, homography_gt)
|
132 |
+
iou_w, _, _ = calc_iou_whole_with_poly(homography_pred, homography_gt)
|
133 |
+
rep_err = calc_reproj_error(homography_pred, homography_gt)
|
134 |
+
proj_err = calc_proj_error(homography_pred, homography_gt)
|
135 |
+
|
136 |
+
iou_part_list.append(iou_p)
|
137 |
+
iou_whole_list.append(iou_w)
|
138 |
+
rep_err_list.append(rep_err)
|
139 |
+
proj_err_list.append(proj_err)
|
140 |
+
|
141 |
+
|
142 |
+
print(f'Completeness: {1-missed/len(homographies)}')
|
143 |
+
print('IOU Part')
|
144 |
+
print(f'mean: {np.mean(iou_part_list)} \t median: {np.median(iou_part_list)}')
|
145 |
+
print('\nIOU Whole')
|
146 |
+
print(f'mean: {np.mean(iou_whole_list)} \t median: {np.median(iou_whole_list)}')
|
147 |
+
print('\nReprojection Err.')
|
148 |
+
print(f'mean: {np.mean(rep_err_list)} \t median: {np.median(rep_err_list)}')
|
149 |
+
print('\nProjection Err. (meters)')
|
150 |
+
print(f'mean: {np.mean(proj_err_list) * 0.9144} \t median: {np.median(proj_err_list) * 0.9144}')
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
scripts/inference_sn.py
ADDED
@@ -0,0 +1,186 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
import yaml
|
6 |
+
import torch
|
7 |
+
import zipfile
|
8 |
+
import argparse
|
9 |
+
import warnings
|
10 |
+
import numpy as np
|
11 |
+
import torchvision.transforms as T
|
12 |
+
import torchvision.transforms.functional as f
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
18 |
+
|
19 |
+
from model.cls_hrnet import get_cls_net
|
20 |
+
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
|
21 |
+
from utils.utils_keypoints import KeypointsDB
|
22 |
+
from utils.utils_lines import LineKeypointsDB
|
23 |
+
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
|
24 |
+
coords_to_dict, complete_keypoints
|
25 |
+
from utils.utils_calib import FramebyFrameCalib
|
26 |
+
|
27 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
28 |
+
warnings.filterwarnings("ignore", category=np.RankWarning)
|
29 |
+
|
30 |
+
|
31 |
+
def parse_args():
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument("--cfg", type=str, required=True,
|
34 |
+
help="Path to the (kp model) configuration file")
|
35 |
+
parser.add_argument("--cfg_l", type=str, required=True,
|
36 |
+
help="Path to the (line model) configuration file")
|
37 |
+
parser.add_argument("--root_dir", type=str, required=True,
|
38 |
+
help="Root directory")
|
39 |
+
parser.add_argument("--split", type=str, required=True,
|
40 |
+
help="Dataset split")
|
41 |
+
parser.add_argument("--save_dir", type=str, required=True,
|
42 |
+
help="Saving file path")
|
43 |
+
parser.add_argument("--weights_kp", type=str, required=True,
|
44 |
+
help="Model (keypoints) weigths to use")
|
45 |
+
parser.add_argument("--weights_line", type=str, required=True,
|
46 |
+
help="Model (lines) weigths to use")
|
47 |
+
parser.add_argument("--cuda", type=str, default="cuda:0",
|
48 |
+
help="CUDA device index (default: 'cuda:0')")
|
49 |
+
parser.add_argument("--kp_th", type=float, default="0.1")
|
50 |
+
parser.add_argument("--line_th", type=float, default="0.1")
|
51 |
+
parser.add_argument("--max_reproj_err", type=float, default="50")
|
52 |
+
parser.add_argument("--main_cam_only", action='store_true')
|
53 |
+
parser.add_argument('--use_gt', action='store_true', help='Use ground truth annotations (default: False)')
|
54 |
+
|
55 |
+
args = parser.parse_args()
|
56 |
+
return args
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
args = parse_args()
|
61 |
+
|
62 |
+
files = glob.glob(os.path.join(args.root_dir + args.split, "*.jpg"))
|
63 |
+
|
64 |
+
if args.main_cam_only:
|
65 |
+
cam_info = json.load(open(args.root_dir + args.split + '/match_info_cam_gt.json'))
|
66 |
+
files = [file for file in files if file.split('/')[-1] in cam_info.keys()]
|
67 |
+
files = [file for file in files if cam_info[file.split('/')[-1]]['camera'] == 'Main camera center']
|
68 |
+
# files = [file for file in files if int(match_info[file.split('/')[-1]]['ms_time']) == \
|
69 |
+
# int(match_info[file.split('/')[-1]]['replay_time'])]
|
70 |
+
|
71 |
+
if args.main_cam_only:
|
72 |
+
zip_name = args.save_dir + args.split + '_main.zip'
|
73 |
+
else:
|
74 |
+
zip_name = args.save_dir + args.split + '.zip'
|
75 |
+
|
76 |
+
if args.use_gt:
|
77 |
+
if args.main_cam_only:
|
78 |
+
zip_name_pred = args.save_dir + args.split + '_main_gt.zip'
|
79 |
+
else:
|
80 |
+
zip_name_pred = args.save_dir + args.split + '_gt.zip'
|
81 |
+
else:
|
82 |
+
if args.main_cam_only:
|
83 |
+
zip_name_pred = args.save_dir + args.split + '_main_pred.zip'
|
84 |
+
else:
|
85 |
+
zip_name_pred = args.save_dir + args.split + '_pred.zip'
|
86 |
+
|
87 |
+
print(f"Saving results in {args.save_dir}")
|
88 |
+
print(f"file: {zip_name_pred}")
|
89 |
+
|
90 |
+
if args.use_gt:
|
91 |
+
transform = T.Resize((540, 960))
|
92 |
+
cam = FramebyFrameCalib(960, 540, denormalize=True)
|
93 |
+
|
94 |
+
with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
|
95 |
+
samples, complete = 0., 0.
|
96 |
+
for file in tqdm(files, desc="Processing Images"):
|
97 |
+
image = Image.open(file)
|
98 |
+
file_name = file.split('/')[-1].split('.')[0]
|
99 |
+
samples += 1
|
100 |
+
|
101 |
+
json_path = file.split('.')[0] + ".json"
|
102 |
+
f = open(json_path)
|
103 |
+
data = json.load(f)
|
104 |
+
|
105 |
+
kp_db = KeypointsDB(data, image)
|
106 |
+
line_db = LineKeypointsDB(data, image)
|
107 |
+
heatmaps, _ = kp_db.get_tensor_w_mask()
|
108 |
+
heatmaps = torch.tensor(heatmaps).unsqueeze(0)
|
109 |
+
heatmaps_l = line_db.get_tensor()
|
110 |
+
heatmaps_l = torch.tensor(heatmaps_l).unsqueeze(0)
|
111 |
+
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
|
112 |
+
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
|
113 |
+
kp_dict = coords_to_dict(kp_coords, threshold=0.01)
|
114 |
+
lines_dict = coords_to_dict(line_coords, threshold=0.01)
|
115 |
+
|
116 |
+
cam.update(kp_dict, lines_dict)
|
117 |
+
final_params_dict = cam.heuristic_voting()
|
118 |
+
# final_params_dict = cam.calibrate(5)
|
119 |
+
|
120 |
+
if final_params_dict:
|
121 |
+
complete += 1
|
122 |
+
cam_params = final_params_dict['cam_params']
|
124 |
+
json_data = json.dumps(cam_params)
|
125 |
+
zip_file.writestr(f"camera_{file_name}.json", json_data)
|
126 |
+
|
127 |
+
else:
|
128 |
+
device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
|
129 |
+
cfg = yaml.safe_load(open(args.cfg, 'r'))
|
130 |
+
cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
|
131 |
+
|
132 |
+
loaded_state = torch.load(args.weights_kp, map_location=device)
|
133 |
+
model = get_cls_net(cfg)
|
134 |
+
model.load_state_dict(loaded_state)
|
135 |
+
model.to(device)
|
136 |
+
model.eval()
|
137 |
+
|
138 |
+
loaded_state_l = torch.load(args.weights_line, map_location=device)
|
139 |
+
model_l = get_cls_net_l(cfg_l)
|
140 |
+
model_l.load_state_dict(loaded_state_l)
|
141 |
+
model_l.to(device)
|
142 |
+
model_l.eval()
|
143 |
+
|
144 |
+
transform = T.Resize((540, 960))
|
145 |
+
cam = FramebyFrameCalib(960, 540)
|
146 |
+
|
147 |
+
with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
|
148 |
+
samples, complete = 0., 0.
|
149 |
+
for file in tqdm(files, desc="Processing Images"):
|
150 |
+
image = Image.open(file)
|
151 |
+
file_name = file.split('/')[-1].split('.')[0]
|
152 |
+
samples += 1
|
153 |
+
|
154 |
+
with torch.no_grad():
|
155 |
+
image = f.to_tensor(image).float().to(device).unsqueeze(0)
|
156 |
+
image = image if image.size()[-1] == 960 else transform(image)
|
157 |
+
b, c, h, w = image.size()
|
158 |
+
heatmaps = model(image)
|
159 |
+
heatmaps_l = model_l(image)
|
160 |
+
|
161 |
+
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
|
162 |
+
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
|
163 |
+
kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th)
|
164 |
+
lines_dict = coords_to_dict(line_coords, threshold=args.line_th)
|
165 |
+
kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h)
|
166 |
+
|
167 |
+
cam.update(kp_dict, lines_dict)
|
168 |
+
final_params_dict = cam.heuristic_voting(refine_lines=True)
|
169 |
+
|
170 |
+
if final_params_dict:
|
171 |
+
if final_params_dict['rep_err'] <= args.max_reproj_err:
|
172 |
+
complete += 1
|
173 |
+
cam_params = final_params_dict['cam_params']
|
174 |
+
json_data = json.dumps(cam_params)
|
175 |
+
zip_file.writestr(f"camera_{file_name}.json", json_data)
|
176 |
+
|
177 |
+
with zipfile.ZipFile(zip_name, 'w') as zip_file:
|
178 |
+
for file in tqdm(files, desc="Processing Images"):
|
179 |
+
file_name = file.split('/')[-1].split('.')[0]
|
180 |
+
data = json.load(open(file.split('.')[0] + ".json"))
|
181 |
+
json_data = json.dumps(data)
|
182 |
+
zip_file.writestr(f"{file_name}.json", json_data)
|
183 |
+
|
184 |
+
print(f'Completed {complete} / {samples}')
|
185 |
+
|
186 |
+
|
scripts/inference_tswc.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
import yaml
|
6 |
+
import torch
|
7 |
+
import zipfile
|
8 |
+
import argparse
|
9 |
+
import warnings
|
10 |
+
import numpy as np
|
11 |
+
import torchvision.transforms as T
|
12 |
+
import torchvision.transforms.functional as f
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
18 |
+
from model.cls_hrnet import get_cls_net
|
19 |
+
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
|
20 |
+
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
|
21 |
+
complete_keypoints, coords_to_dict
|
22 |
+
from utils.utils_keypoints import KeypointsDB
|
23 |
+
from utils.utils_lines import LineKeypointsDB
|
24 |
+
|
25 |
+
|
26 |
+
def parse_args():
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
parser.add_argument("--cfg", type=str, required=True,
|
29 |
+
help="Path to the (kp model) configuration file")
|
30 |
+
parser.add_argument("--cfg_l", type=str, required=True,
|
31 |
+
help="Path to the (line model) configuration file")
|
32 |
+
parser.add_argument("--root_dir", type=str, required=True,
|
33 |
+
help="Root directory")
|
34 |
+
parser.add_argument("--split", type=str, required=True,
|
35 |
+
help="Dataset split")
|
36 |
+
parser.add_argument("--save_dir", type=str, required=True,
|
37 |
+
help="Root directory")
|
38 |
+
parser.add_argument("--weights_kp", type=str, required=True,
|
39 |
+
help="Model (keypoints) weigths to use")
|
40 |
+
parser.add_argument("--weights_line", type=str, required=True,
|
41 |
+
help="Model (lines) weigths to use")
|
42 |
+
parser.add_argument("--cuda", type=str, default="cuda:0",
|
43 |
+
help="CUDA device index (default: 'cuda:0')")
|
44 |
+
parser.add_argument("--kp_th", type=float, default="0.1")
|
45 |
+
parser.add_argument("--line_th", type=float, default="0.1")
|
46 |
+
parser.add_argument("--batch", type=int, default=1, help="Batch size")
|
47 |
+
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
|
48 |
+
|
49 |
+
|
50 |
+
args = parser.parse_args()
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def get_files(file_paths):
|
55 |
+
jpg_files = []
|
56 |
+
for file_path in file_paths:
|
57 |
+
directory_path = os.path.join(os.path.join(args.root_dir, "Dataset/80_95"), file_path)
|
58 |
+
if os.path.exists(directory_path):
|
59 |
+
files = os.listdir(directory_path)
|
60 |
+
jpg_files.extend([os.path.join(directory_path, file) for file in files if file.endswith('.jpg')])
|
61 |
+
|
62 |
+
jpg_files = sorted(jpg_files)
|
63 |
+
return jpg_files
|
64 |
+
|
65 |
+
def get_homographies(file_paths):
|
66 |
+
npy_files = []
|
67 |
+
for file_path in file_paths:
|
68 |
+
directory_path = os.path.join(os.path.join(args.root_dir, "Annotations/80_95"), file_path)
|
69 |
+
if os.path.exists(directory_path):
|
70 |
+
files = os.listdir(directory_path)
|
71 |
+
npy_files.extend([os.path.join(directory_path, file) for file in files if file.endswith('.npy')])
|
72 |
+
|
73 |
+
npy_files = sorted(npy_files)
|
74 |
+
return npy_files
|
75 |
+
|
76 |
+
|
77 |
+
def make_file_name(file):
|
78 |
+
file = "TS-WorldCup/" + file.split("TS-WorldCup/")[-1]
|
79 |
+
splits = file.split('/')
|
80 |
+
side = splits[3]
|
81 |
+
match = splits[4]
|
82 |
+
image = splits[5]
|
83 |
+
frame = 'IMG_' + image.split('.')[0].split('_')[-1]
|
84 |
+
return side + '-' + match + '-' + frame
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
args = parse_args()
|
89 |
+
|
90 |
+
with open(args.root_dir + args.split + '.txt', 'r') as file:
|
91 |
+
# Read lines from the file and remove trailing newline characters
|
92 |
+
seqs = [line.strip() for line in file.readlines()]
|
93 |
+
|
94 |
+
files = get_files(seqs)
|
95 |
+
homographies = get_homographies(seqs)
|
96 |
+
|
97 |
+
zip_name_pred = args.save_dir + args.split + '_pred.zip'
|
98 |
+
|
99 |
+
device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
|
100 |
+
cfg = yaml.safe_load(open(args.cfg, 'r'))
|
101 |
+
cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
|
102 |
+
|
103 |
+
loaded_state = torch.load(args.weights_kp, map_location=device)
|
104 |
+
model = get_cls_net(cfg)
|
105 |
+
model.load_state_dict(loaded_state)
|
106 |
+
model.to(device)
|
107 |
+
model.eval()
|
108 |
+
|
109 |
+
loaded_state_l = torch.load(args.weights_line, map_location=device)
|
110 |
+
model_l = get_cls_net_l(cfg_l)
|
111 |
+
model_l.load_state_dict(loaded_state_l)
|
112 |
+
model_l.to(device)
|
113 |
+
model_l.eval()
|
114 |
+
|
115 |
+
transform = T.Resize((540, 960))
|
116 |
+
|
117 |
+
with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
|
118 |
+
for count in tqdm(range(len(files)), desc="Processing Images"):
|
119 |
+
image = Image.open(files[count])
|
120 |
+
image = f.to_tensor(image).float().to(device).unsqueeze(0)
|
121 |
+
image = image if image.size()[-1] == 960 else transform(image)
|
122 |
+
b, c, h, w = image.size()
|
123 |
+
|
124 |
+
|
125 |
+
with torch.no_grad():
|
126 |
+
heatmaps = model(image)
|
127 |
+
heatmaps_l = model_l(image)
|
128 |
+
|
129 |
+
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
|
130 |
+
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
|
131 |
+
kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th, ground_plane_only=True)
|
132 |
+
lines_dict = coords_to_dict(line_coords, threshold=args.line_th, ground_plane_only=True)
|
133 |
+
final_kp_dict, final_lines_dict = complete_keypoints(kp_dict[0], lines_dict[0],
|
134 |
+
w=w, h=h, normalize=True)
|
135 |
+
final_dict = {'kp': final_kp_dict, 'lines': final_lines_dict}
|
136 |
+
|
137 |
+
json_data = json.dumps(final_dict)
|
138 |
+
zip_file.writestr(f"{make_file_name(files[count])}.json", json_data)
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
scripts/inference_wc14.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
from typing import final
|
6 |
+
|
7 |
+
import yaml
|
8 |
+
import torch
|
9 |
+
import zipfile
|
10 |
+
import argparse
|
11 |
+
import warnings
|
12 |
+
import numpy as np
|
13 |
+
import torchvision.transforms as T
|
14 |
+
import torchvision.transforms.functional as f
|
15 |
+
|
16 |
+
from tqdm import tqdm
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
20 |
+
from model.cls_hrnet import get_cls_net
|
21 |
+
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
|
22 |
+
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
|
23 |
+
complete_keypoints, coords_to_dict
|
24 |
+
from utils.utils_keypoints import KeypointsDB
|
25 |
+
from utils.utils_lines import LineKeypointsDB
|
26 |
+
|
27 |
+
|
28 |
+
def parse_args():
|
29 |
+
parser = argparse.ArgumentParser()
|
30 |
+
parser.add_argument("--cfg", type=str, required=True,
|
31 |
+
help="Path to the (kp model) configuration file")
|
32 |
+
parser.add_argument("--cfg_l", type=str, required=True,
|
33 |
+
help="Path to the (line model) configuration file")
|
34 |
+
parser.add_argument("--root_dir", type=str, required=True,
|
35 |
+
help="Root directory")
|
36 |
+
parser.add_argument("--split", type=str, required=True,
|
37 |
+
help="Dataset split")
|
38 |
+
parser.add_argument("--save_dir", type=str, required=True,
|
39 |
+
help="Root directory")
|
40 |
+
parser.add_argument("--weights_kp", type=str, required=True,
|
41 |
+
help="Model (keypoints) weigths to use")
|
42 |
+
parser.add_argument("--weights_line", type=str, required=True,
|
43 |
+
help="Model (lines) weigths to use")
|
44 |
+
parser.add_argument("--cuda", type=str, default="cuda:0",
|
45 |
+
help="CUDA device index (default: 'cuda:0')")
|
46 |
+
parser.add_argument("--kp_th", type=float, default="0.1")
|
47 |
+
parser.add_argument("--line_th", type=float, default="0.1")
|
48 |
+
parser.add_argument("--batch", type=int, default=1, help="Batch size")
|
49 |
+
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
|
50 |
+
parser.add_argument('--use_gt', action='store_true', help='Use ground truth (default: False)')
|
51 |
+
|
52 |
+
|
53 |
+
args = parser.parse_args()
|
54 |
+
return args
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
args = parse_args()
|
60 |
+
|
61 |
+
files = glob.glob(os.path.join(args.root_dir + args.split, "*.jpg"))
|
62 |
+
|
63 |
+
if args.use_gt:
|
64 |
+
zip_name_pred = args.save_dir + args.split + '_gt.zip'
|
65 |
+
else:
|
66 |
+
zip_name_pred = args.save_dir + args.split + '_pred.zip'
|
67 |
+
|
68 |
+
|
69 |
+
if args.use_gt:
|
70 |
+
device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
|
71 |
+
|
72 |
+
with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
|
73 |
+
for file in tqdm(files, desc="Processing Images"):
|
74 |
+
image = Image.open(file)
|
75 |
+
w, h = image.size
|
76 |
+
|
77 |
+
homography_file = args.root_dir + args.split + '/' + \
|
78 |
+
file.split('/')[-1].split('.')[0] + '.homographyMatrix'
|
79 |
+
|
80 |
+
json_path = file.split('.')[0] + ".json"
|
81 |
+
f = open(json_path)
|
82 |
+
data = json.load(f)
|
83 |
+
kp_db = KeypointsDB(data, image)
|
84 |
+
line_db = LineKeypointsDB(data, image)
|
85 |
+
heatmaps, _ = kp_db.get_tensor_w_mask()
|
86 |
+
heatmaps = torch.tensor(heatmaps).unsqueeze(0)
|
87 |
+
heatmaps_l = line_db.get_tensor()
|
88 |
+
heatmaps_l = torch.tensor(heatmaps_l).unsqueeze(0)
|
89 |
+
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
|
90 |
+
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
|
91 |
+
kp_dict = coords_to_dict(kp_coords, threshold=0.1)
|
92 |
+
lines_dict = coords_to_dict(line_coords, threshold=0.1)
|
93 |
+
final_kp_dict = complete_keypoints(kp_dict, lines_dict, w=w, h=h, normalize=True)
|
94 |
+
final_dict = {'kp': final_kp_dict, 'lines': lines_dict}
|
95 |
+
|
96 |
+
json_data = json.dumps(final_dict)
|
97 |
+
zip_file.writestr(f"{file.split('/')[-1].split('.')[0]}.json", json_data)
|
98 |
+
|
99 |
+
|
100 |
+
else:
|
101 |
+
device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
|
102 |
+
cfg = yaml.safe_load(open(args.cfg, 'r'))
|
103 |
+
cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
|
104 |
+
|
105 |
+
loaded_state = torch.load(args.weights_kp, map_location=device)
|
106 |
+
model = get_cls_net(cfg)
|
107 |
+
model.load_state_dict(loaded_state)
|
108 |
+
model.to(device)
|
109 |
+
model.eval()
|
110 |
+
|
111 |
+
loaded_state_l = torch.load(args.weights_line, map_location=device)
|
112 |
+
model_l = get_cls_net_l(cfg_l)
|
113 |
+
model_l.load_state_dict(loaded_state_l)
|
114 |
+
model_l.to(device)
|
115 |
+
model_l.eval()
|
116 |
+
|
117 |
+
transform = T.Resize((540, 960))
|
118 |
+
|
119 |
+
with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
|
120 |
+
for file in tqdm(files, desc="Processing Images"):
|
121 |
+
image = Image.open(file)
|
122 |
+
image = f.to_tensor(image).float().to(device).unsqueeze(0)
|
123 |
+
image = image if image.size()[-1] == 960 else transform(image)
|
124 |
+
b, c, h, w = image.size()
|
125 |
+
|
126 |
+
homography_file = args.root_dir + args.split + '/' + \
|
127 |
+
file.split('/')[-1].split('.')[0] + '.homographyMatrix'
|
128 |
+
|
129 |
+
with torch.no_grad():
|
130 |
+
heatmaps = model(image)
|
131 |
+
heatmaps_l = model_l(image)
|
132 |
+
|
133 |
+
kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
|
134 |
+
line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
|
135 |
+
kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th, ground_plane_only=True)
|
136 |
+
lines_dict = coords_to_dict(line_coords, threshold=args.line_th, ground_plane_only=True)
|
137 |
+
final_kp_dict, final_lines_dict = complete_keypoints(kp_dict[0], lines_dict[0],
|
138 |
+
w=w, h=h, normalize=True)
|
139 |
+
final_dict = {'kp': final_kp_dict, 'lines': final_lines_dict}
|
140 |
+
|
141 |
+
json_data = json.dumps(final_dict)
|
142 |
+
zip_file.writestr(f"{file.split('/')[-1].split('.')[0]}.json", json_data)
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
scripts/run_pipeline_sn22.sh
ADDED
@@ -0,0 +1,23 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set parameters
|
4 |
+
ROOT_DIR="calibration/"
|
5 |
+
SPLIT="test"
|
6 |
+
CFG="config/hrnetv2_w48.yaml"
|
7 |
+
CFG_L="config/hrnetv2_w48_l.yaml"
|
8 |
+
WEIGHTS_KP="weights/SV_kp"
|
9 |
+
WEIGHTS_L="weights/SV_lines"
|
10 |
+
SAVE_DIR="inference/inference_3D/inference_sn22/"
|
11 |
+
DEVICE="cuda:0"
|
12 |
+
KP_TH=0.1611
|
13 |
+
LINE_TH=0.3434
|
14 |
+
MAX_REPROJ_ERR=57
|
15 |
+
GT_FILE="${SAVE_DIR}${SPLIT}_main.zip"
|
16 |
+
PRED_FILE="${SAVE_DIR}${SPLIT}_main_pred.zip"
|
17 |
+
|
18 |
+
|
19 |
+
# Run inference script
|
20 |
+
python scripts/inference_sn.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --max_reproj_err $MAX_REPROJ_ERR --cuda $DEVICE --main_cam_only
|
21 |
+
|
22 |
+
# Run evaluation script
|
23 |
+
python sn_calibration/src/evalai_camera.py -s $GT_FILE -p $PRED_FILE
|
scripts/run_pipeline_sn23.sh
ADDED
@@ -0,0 +1,23 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set parameters
|
4 |
+
ROOT_DIR="datasets/calibration-2023/"
|
5 |
+
SPLIT="test"
|
6 |
+
CFG="config/hrnetv2_w48.yaml"
|
7 |
+
CFG_L="config/hrnetv2_w48_l.yaml"
|
8 |
+
WEIGHTS_KP="weights/MV_kp"
|
9 |
+
WEIGHTS_L="weights/MV_lines"
|
10 |
+
SAVE_DIR="inference/inference_3D/inference_sn23/"
|
11 |
+
DEVICE="cuda:0"
|
12 |
+
KP_TH=0.0712
|
13 |
+
LINE_TH=0.2571
|
14 |
+
MAX_REPROJ_ERR=38
|
15 |
+
GT_FILE="${SAVE_DIR}${SPLIT}.zip"
|
16 |
+
PRED_FILE="${SAVE_DIR}${SPLIT}_pred.zip"
|
17 |
+
|
18 |
+
|
19 |
+
# Run inference script
|
20 |
+
python scripts/inference_sn.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --max_reproj_err $MAX_REPROJ_ERR --cuda $DEVICE
|
21 |
+
|
22 |
+
# Run evaluation script
|
23 |
+
python sn_calibration/src/evalai_camera.py -s $GT_FILE -p $PRED_FILE
|
scripts/run_pipeline_tswc.sh
ADDED
@@ -0,0 +1,20 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set parameters
|
4 |
+
ROOT_DIR="datasets/TS-WorldCup/"
|
5 |
+
SPLIT="test"
|
6 |
+
CFG="config/hrnetv2_w48.yaml"
|
7 |
+
CFG_L="config/hrnetv2_w48_l.yaml"
|
8 |
+
WEIGHTS_KP="weights/SV_FT_TSWC_kp"
|
9 |
+
WEIGHTS_L="weights/SV_FT_TSWC_lines"
|
10 |
+
SAVE_DIR="inference/inference_2D/inference_tswc/"
|
11 |
+
KP_TH=0.2432
|
12 |
+
LINE_TH=0.8482
|
13 |
+
DEVICE="cuda:0"
|
14 |
+
PRED_FILE="${SAVE_DIR}${SPLIT}_pred.zip"
|
15 |
+
|
16 |
+
# Run inference script
|
17 |
+
python scripts/inference_tswc.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --cuda $DEVICE
|
18 |
+
|
19 |
+
# Run evaluation script
|
20 |
+
python scripts/eval_tswc.py --root_dir $ROOT_DIR --split $SPLIT --pred_file $PRED_FILE
|
scripts/run_pipeline_wc14.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set parameters
|
4 |
+
ROOT_DIR="datasets/WC-2014/"
|
5 |
+
SPLIT="test"
|
6 |
+
CFG="config/hrnetv2_w48.yaml"
|
7 |
+
CFG_L="config/hrnetv2_w48_l.yaml"
|
8 |
+
WEIGHTS_KP="weights/SV_FT_WC14_kp"
|
9 |
+
WEIGHTS_L="weights/SV_FT_WC14_lines"
|
10 |
+
SAVE_DIR="inference/inference_2D/inference_wc14/"
|
11 |
+
KP_TH=0.1274
|
12 |
+
LINE_TH=0.1439
|
13 |
+
DEVICE="cuda:0"
|
14 |
+
PRED_FILE="${SAVE_DIR}${SPLIT}_pred.zip"
|
15 |
+
|
16 |
+
# Run inference script
|
17 |
+
python scripts/inference_wc14.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --cuda $DEVICE
|
18 |
+
|
19 |
+
# Run evaluation script
|
20 |
+
python scripts/eval_wc14.py --root_dir $ROOT_DIR --split $SPLIT --pred_file $PRED_FILE
|
scripts/run_pipeline_wc14_3D.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set parameters
|
4 |
+
ROOT_DIR="datasets/WC-2014/"
|
5 |
+
SPLIT="test"
|
6 |
+
CFG="config/hrnetv2_w48.yaml"
|
7 |
+
CFG_L="config/hrnetv2_w48_l.yaml"
|
8 |
+
WEIGHTS_KP="weights/SV_kp"
|
9 |
+
WEIGHTS_L="weights/SV_lines"
|
10 |
+
SAVE_DIR="inference/inference_3D/inference_wc14/"
|
11 |
+
DEVICE="cuda:0"
|
12 |
+
KP_TH=0.0070
|
13 |
+
LINE_TH=0.1513
|
14 |
+
MAX_REPROJ_ERR=83
|
15 |
+
GT_FILE="${SAVE_DIR}${SPLIT}_main.zip"
|
16 |
+
PRED_FILE="${SAVE_DIR}${SPLIT}_main_pred.zip"
|
17 |
+
|
18 |
+
|
19 |
+
# Run inference script
|
20 |
+
python scripts/inference_sn.py --cfg $CFG --cfg_l $CFG_L --weights_kp $WEIGHTS_KP --weights_line $WEIGHTS_L --root_dir $ROOT_DIR --split $SPLIT --save_dir $SAVE_DIR --kp_th $KP_TH --line_th $LINE_TH --max_reproj_err $MAX_REPROJ_ERR --cuda $DEVICE --main_cam_only
|
21 |
+
|
22 |
+
# Run evaluation script
|
23 |
+
python sn_calibration/src/evalai_camera.py -s $GT_FILE -p $PRED_FILE
|
sn_calibration/ChallengeRules.md
ADDED
@@ -0,0 +1,44 @@
|
|
# Guidelines for the Calibration challenge

The 2nd [Calibration challenge]() will be held at the
official [CVSports Workshop](https://vap.aau.dk/cvsports/) at CVPR 2023!
Subscribe to (watch) the repo to receive the latest info regarding the timeline and prizes!

We provide an [evaluation server](https://eval.ai/web/challenges/challenge-page/1946/overview) for anyone competing in the challenge.
This evaluation server handles predictions for the open **test** set and the segregated **challenge** set.

Winners will be announced at the CVSports Workshop at CVPR 2023.
Prizes 💲💲💲 include a $1000 cash award, sponsored by [EVS Broadcast Equipment](https://evs.com/).


## Who can participate / How to participate?

- Any individual can participate in the challenge, except the organizers.
- The participants are recommended to form a team to participate.
- Each team can have one or more members.
- An individual/team can compete on both tasks.
- An individual associated with multiple teams (for a given task) or a team with multiple accounts will be disqualified.
- On both tasks, a participant can only use the images as input.
- A participant is allowed to extract their own visual/audio features with any pre-trained model.

## How to win / What is the prize?

- For each task, the winner is the individual/team who reaches the highest performance on the **challenge** set.
- The metric taken into consideration is the combined metric: completeness * Accuracy@5 (see the sketch after this list).
- The deadline to submit your results is May 30th at 11.59 pm Pacific Time.
- The teams that perform best will be granted $1000 from our sponsor [EVS Broadcast Equipment](https://evs.com/).
- In order to be eligible for the prize, we require the individual/team to provide a short report describing the details of the methodology (CVPR format, max 2 pages).

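The ranking value can be illustrated with a short sketch. The authoritative Accuracy@5 computation lives in the evaluation server (invoked locally via `sn_calibration/src/evalai_camera.py` in the pipeline scripts), so the input layout below is an assumption made purely for illustration.

```python
# Sketch of the combined metric = completeness * Accuracy@5 (illustration only).
# Assumptions: `per_image_acc` maps image ids to 1.0 / 0.0 (prediction within the
# 5-pixel threshold or not) for the images where a camera was predicted;
# `n_total` is the number of images in the split.
def combined_metric(per_image_acc: dict, n_total: int) -> float:
    if not per_image_acc or n_total == 0:
        return 0.0
    completeness = len(per_image_acc) / n_total
    accuracy_at_5 = sum(per_image_acc.values()) / len(per_image_acc)
    return completeness * accuracy_at_5

# 80 of 100 images get a camera, 60 of those pass the 5-pixel threshold -> 0.8 * 0.75 = 0.6
print(combined_metric({f"img_{i}": float(i < 60) for i in range(80)}, n_total=100))
```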

## Important dates

Note that these dates are tentative and subject to change if necessary.

- **February 1:** Open evaluation server on the (Open) Test set.
- **February 15:** Open evaluation server on the (Segregated) Challenge set.
- **May 30:** Close evaluation server.
- **June 6:** Deadline for submitting the report.
- **June 19:** A full-day workshop at CVPR 2023.

For any further doubt or concern, please raise an issue in this repository, or contact us directly on [Discord](https://discord.gg/SM8uHj9mkP).
sn_calibration/README.md
ADDED
@@ -0,0 +1,440 @@
1 |
+
|
2 |
+

|
3 |
+
|
4 |
+
# EVS Camera Calibration Challenge
|
5 |
+
|
6 |
+
|
7 |
+
Welcome to the EVS camera calibration challenge! This challenge is sponsored by EVS Broadcast Equipment, and is
|
8 |
+
developed in collaboration with the SoccerNet team.
|
9 |
+
|
10 |
+
This challenge consists of two distinct tasks which are defined hereafter. We provide sample code and baselines for each
|
11 |
+
task to help you get started!
|
12 |
+
|
13 |
+
|
14 |
+
Participate in our upcoming Challenge at the [CVSports](https://vap.aau.dk/cvsports/) workshop at CVPR and try to win up to 1000$ sponsored by [EVS](https://evs.com/)! All details are available on the [challenge website](https://eval.ai/web/challenges/challenge-page/1537/overview ), or on the [main page](https://www.soccer-net.org/).
|
15 |
+
|
16 |
+
The participation deadline is fixed at the 30th of May 2023. The official rules and guidelines are available on [ChallengeRules.md](ChallengeRules.md).
|
17 |
+
|
18 |
+
<a href="https://www.youtube.com/watch?v=nxywN6X_0oE">
|
19 |
+
<p align="center"><img src="images/Thumbnail.png" width="720"></p>
|
20 |
+
</a>
|
21 |
+
|
22 |
+
### 2023 Leaderboard
|
23 |
+
|
24 |
+
<p><table class="dataframe">
|
25 |
+
<thead>
|
26 |
+
<tr style="text-align: right;">
|
27 |
+
<th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Team</th>
|
28 |
+
<th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Combined Metric</th>
|
29 |
+
<th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Accuracy@5</th>
|
30 |
+
<th style = "background-color: #FFFFFF;font-family: Century Gothic, sans-serif;font-size: medium;color: #305496;text-align: left;border-bottom: 2px solid #305496;padding: 0px 20px 0px 0px;width: auto">Completeness</th>
|
31 |
+
</tr>
|
32 |
+
</thead>
|
33 |
+
<tbody>
|
34 |
+
<tr>
|
35 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Sportlight</td>
|
36 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.55</td>
|
37 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">73.22</td>
|
38 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">75.59</td>
|
39 |
+
</tr>
|
40 |
+
<tr>
|
41 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Spiideo</td>
|
42 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.53</td>
|
43 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">52.95</td>
|
44 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">99.96</td>
|
45 |
+
</tr>
|
46 |
+
<tr>
|
47 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">SAIVA_Calibration</td>
|
48 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.53</td>
|
49 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">60.33</td>
|
50 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">87.22</td>
|
51 |
+
</tr>
|
52 |
+
<tr>
|
53 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">BPP</td>
|
54 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.5</td>
|
55 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">69.12</td>
|
56 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">72.54</td>
|
57 |
+
</tr>
|
58 |
+
<tr>
|
59 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">ikapetan</td>
|
60 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.43</td>
|
61 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">53.78</td>
|
62 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">79.71</td>
|
63 |
+
</tr>
|
64 |
+
<tr>
|
65 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">NASK</td>
|
66 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.41</td>
|
67 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">53.01</td>
|
68 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">77.81</td>
|
69 |
+
</tr>
|
70 |
+
<tr>
|
71 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Mike Azatov and Jonas Theiner</td>
|
72 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.41</td>
|
73 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">58.61</td>
|
74 |
+
<td style = "background-color: #D9E1F2;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">69.34</td>
|
75 |
+
</tr>
|
76 |
+
<tr>
|
77 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">Baseline</td>
|
78 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">0.08</td>
|
79 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">13.54</td>
|
80 |
+
<td style = "background-color: white; color: black;font-family: Century Gothic, sans-serif;font-size: medium;text-align: left;padding: 0px 20px 0px 0px;width: auto">61.54</td>
|
81 |
+
</tr>
|
82 |
+
</tbody>
|
83 |
+
</table></p>
|
84 |
+
|
85 |
+
## Table of content
|
86 |
+
|
87 |
+
- Install
|
88 |
+
- Dataset
|
89 |
+
- Soccer pitch annotations
|
90 |
+
- Camera calibration
|
91 |
+
- Definition
|
92 |
+
- Evaluation
|
93 |
+
- Baseline
|
94 |
+
|
95 |
+
## Install
|
96 |
+
|
97 |
+
Download the network weights from google drive https://drive.google.com/file/d/1dbN7LdMV03BR1Eda8n7iKNIyYp9r07sM/view?usp=sharing
|
98 |
+
and place them in [resources](resources).
|
99 |
+
|
100 |
+
With python 3 already installed, the python environment can be installed with the following command:
|
101 |
+
|
102 |
+
```
|
103 |
+
pip install -r requirements.txt
|
104 |
+
```
|
105 |
+
|
106 |
+
## Dataset
|
107 |
+
|
108 |
+
SoccerNet is a dataset containing 400 broadcasted videos of whole soccer games. The dataset can be found
|
109 |
+
here : https://soccer-net.org/
|
110 |
+
The authors give access to template code and a python package to get started with the dataset. All the documentation
|
111 |
+
about these tools can be found here : https://github.com/SilvioGiancola/SoccerNetv2-DevKit
|
112 |
+
|
113 |
+
All the data needed for challenge can be downloaded with these lines :
|
114 |
+
|
115 |
+
```python
|
116 |
+
from SoccerNet.Downloader import SoccerNetDownloader as SNdl
|
117 |
+
soccerNetDownloader = SNdl(LocalDirectory="path/to/SoccerNet")
|
118 |
+
soccerNetDownloader.downloadDataTask(task="calibration-2023", split=["train","valid","test","challenge"])
|
119 |
+
```
|
120 |
+
|
121 |
+
Historically, the dataset was first released for an action spotting task. In its first version, the images corresponding
|
122 |
+
to soccer actions (goals, fouls, etc.) were identified. In the following editions, more annotations have been associated
|
123 |
+
with those images. In the latest version of the dataset (SoccerNetV3), the extremities of the soccer pitch line
|
124 |
+
markings have been annotated. In partnership with the SoccerNet team, we use these annotations in a new challenge. The
|
125 |
+
challenge is divided into two tasks, where solving the first leads to the second. The first is a soccer pitch
|
126 |
+
element localisation task, whose output is then used for the second task, a camera calibration task.
|
127 |
+
|
128 |
+
**/!\ New**: some annotations have been added: for some images, there are new point annotations along the pitch marking lines.
|
129 |
+
For straight pitch marking lines, you can always assume that the extremities are annotated, and sometimes,
|
130 |
+
if the image has been reannotated, there will be a few extra points along the imaged line.
|
131 |
+
|
132 |
+
|
133 |
+
### Soccer pitch annotations
|
134 |
+
|
135 |
+
Camera calibration is made easier by the presence of an object with a known shape in the image. For soccer
|
136 |
+
content, the soccer pitch can be used as a target for the camera calibration because it has a known shape and its
|
137 |
+
dimensions are specified in the International Football Association Board's Laws of the
|
138 |
+
Game (https://digitalhub.fifa.com/m/5371a6dcc42fbb44/original/d6g1medsi8jrrd3e4imp-pdf.pdf).
|
139 |
+
|
140 |
+
Moreover, we define a set of semantic labels for each semantic element of the soccer pitch. We also define the bottom
|
141 |
+
side of the pitch as the one where the main and 16-meter broadcast cameras are installed.
|
142 |
+
|
143 |
+

|
144 |
+
|
145 |
+
1. Big rect. left bottom,
|
146 |
+
2. Big rect. left main,
|
147 |
+
3. Big rect. left top,
|
148 |
+
4. Big rect. right bottom,
|
149 |
+
5. Big rect. right main,
|
150 |
+
6. Big rect. right top,
|
151 |
+
7. Circle central,
|
152 |
+
8. Circle left,
|
153 |
+
9. Circle right,
|
154 |
+
10. Goal left crossbar,
|
155 |
+
11. Goal left post left,
|
156 |
+
12. Goal left post right,
|
157 |
+
13. Goal right crossbar,
|
158 |
+
14. Goal right post left,
|
159 |
+
15. Goal right post right,
|
160 |
+
16. Middle line,
|
161 |
+
17. Side line bottom,
|
162 |
+
18. Side line left,
|
163 |
+
19. Side line right,
|
164 |
+
20. Side line top,
|
165 |
+
21. Small rect. left bottom,
|
166 |
+
22. Small rect. left main,
|
167 |
+
23. Small rect. left top,
|
168 |
+
24. Small rect. right bottom,
|
169 |
+
25. Small rect. right main,
|
170 |
+
26. Small rect. right top
|
171 |
+
|
172 |
+
In the third version of SoccerNet, there are new annotations for each image of the dataset. These new annotations
|
173 |
+
consist of the list of all extremities of the semantic pitch elements present in the image. The extremities are
|
174 |
+
a pair of 2D point coordinates.
|
175 |
+
|
176 |
+
For the circles drawn on the pitch, the annotations consist of a list of points which roughly outline the shape of the
|
177 |
+
circle when connected. Note that due to new annotations, the sequential order of circle points is no longer guaranteed.
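
As a hedged sketch of how these annotations can be read (the layout follows the dataloader provided in src/dataloader.py; the file path is a placeholder), each class name maps to a list of points whose coordinates are normalized by the image width and height:

```python
import json

# Placeholder path: one annotation file per frame, named "{frame_index}.json".
with open("path/to/SoccerNet/calibration/train/00001.json") as f:
    annotations = json.load(f)  # e.g. {"Side line top": [{"x": 0.41, "y": 0.28}, ...], ...}

image_width, image_height = 960, 540
for class_name, points in annotations.items():
    # Convert normalized coordinates back to pixel positions.
    pixels = [(p["x"] * image_width, p["y"] * image_height) for p in points]
    print(class_name, pixels[:2])
```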
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
## Camera calibration task
|
182 |
+
|
183 |
+
As a soccer pitch has known dimensions, it is possible to use the soccer pitch as a calibration target in order to
|
184 |
+
calibrate the camera. Since the pitch is roughly planar, we can model the transformation applied by the camera to the pitch by a
|
185 |
+
homography. In order to estimate the homography between the pitch and its image, we only need 4 lines. We provide a
|
186 |
+
baseline in order to extract camera parameters from this kind of image.
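
As a hedged illustration (not the provided baseline itself), the sketch below estimates such a pitch-to-image homography with OpenCV from four point correspondences, e.g. the extremities of four annotated lines; the coordinate values are placeholders.

```python
import cv2 as cv
import numpy as np

# Placeholder correspondences: pitch-plane coordinates in meters (Z = 0)
# and the matching pixel positions of the same line extremities.
pitch_points = np.array([[-52.5, -34.0], [52.5, -34.0], [52.5, 34.0], [-52.5, 34.0]], dtype=np.float32)
image_points = np.array([[102.0, 480.0], [850.0, 470.0], [700.0, 120.0], [240.0, 125.0]], dtype=np.float32)

# Homography mapping the pitch plane to the image (four correspondences are the minimum).
H, _ = cv.findHomography(pitch_points, image_points)
print(H)
```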
|
187 |
+
|
188 |
+
### Definition
|
189 |
+
|
190 |
+
In this task, we ask you to provide valid camera parameters for each image of the challenge set.
|
191 |
+
|
192 |
+
Given a common 3D pitch template, we will use the camera parameters produced by your algorithm in order to estimate the
|
193 |
+
reprojection error induced by those parameters. The camera parameters include the lens parameters, the orientation,
|
194 |
+
and the translation with respect to the world reference axis system, which we define as follows:
|
195 |
+
|
196 |
+

|
197 |
+
|
198 |
+
#### Rotation convention
|
199 |
+
|
200 |
+
Following Euler angles convention, we use a ZXZ succession of intrinsic rotations in order to describe the orientation
|
201 |
+
of the camera. Starting from the world reference axis system, we first apply a rotation around the Z axis to pan the
|
202 |
+
camera. Then the obtained axis system is rotated around its x axis in order to tilt the camera. Then the last rotation
|
203 |
+
around the z axis of the new axis system rolls the camera. Note that this z axis is the principal axis of the
|
204 |
+
camera.
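
As a minimal sketch of this convention (angles in radians, consistent with the helper provided in camera.py):

```python
import numpy as np

def zxz_rotation(pan, tilt, roll):
    """Compose the ZXZ succession of intrinsic rotations described above."""
    Rz_pan = np.array([[np.cos(pan), -np.sin(pan), 0.],
                       [np.sin(pan),  np.cos(pan), 0.],
                       [0.,           0.,          1.]])
    Rx_tilt = np.array([[1., 0.,            0.],
                        [0., np.cos(tilt), -np.sin(tilt)],
                        [0., np.sin(tilt),  np.cos(tilt)]])
    Rz_roll = np.array([[np.cos(roll), -np.sin(roll), 0.],
                        [np.sin(roll),  np.cos(roll), 0.],
                        [0.,            0.,           1.]])
    # Pan (Z), then tilt (new X), then roll (new Z, the camera principal axis).
    return Rz_pan @ Rx_tilt @ Rz_roll

R = zxz_rotation(np.radians(14.86), np.radians(78.84), np.radians(-2.22))
```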
|
205 |
+
|
206 |
+
The lens parameters produced must follow the pinhole model. Additionally, the parameters can include radial, tangential
|
207 |
+
and thin prism distortion parameters. This corresponds to the full model of
|
208 |
+
OpenCV: https://docs.opencv.org/4.5.0/d9/d0c/group__calib3d.html#details
|
209 |
+
|
210 |
+
For each image of the test set, we expect to receive a JSON file named "**camera_{frame_index}.json**" containing a
|
211 |
+
dictionary with the camera parameters.
|
212 |
+
|
213 |
+
```
|
214 |
+
# camera_00001.json
|
215 |
+
{
|
216 |
+
"pan_degrees": 14.862476218376278,
|
217 |
+
"tilt_degrees": 78.83988009048775,
|
218 |
+
"roll_degrees": -2.2210919345134497,
|
219 |
+
"position_meters": [
|
220 |
+
32.6100008989567,
|
221 |
+
67.9363036953344,
|
222 |
+
-14.898134157887508
|
223 |
+
],
|
224 |
+
"x_focal_length": 3921.6013418112757,
|
225 |
+
"y_focal_length": 3921.601341812138,
|
226 |
+
"principal_point": [
|
227 |
+
480.0,
|
228 |
+
270.0
|
229 |
+
],
|
230 |
+
"radial_distortion": [
|
231 |
+
0.0,
|
232 |
+
0.0,
|
233 |
+
0.0,
|
234 |
+
0.0,
|
235 |
+
0.0,
|
236 |
+
0.0
|
237 |
+
],
|
238 |
+
"tangential_distortion": [
|
239 |
+
0.0,
|
240 |
+
0.0
|
241 |
+
],
|
242 |
+
"thin_prism_distortion": [
|
243 |
+
0.0,
|
244 |
+
0.0,
|
245 |
+
0.0,
|
246 |
+
0.0
|
247 |
+
]
|
248 |
+
}
|
249 |
+
```
|
250 |
+
|
251 |
+
The results should be organized as follows:
|
252 |
+
|
253 |
+
|
254 |
+
```
|
255 |
+
test.zip
|
256 |
+
|__ camera_00001.json
|
257 |
+
|__ camera_00002.json
|
258 |
+
```
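
A minimal sketch of how such an archive could be produced, assuming a `predictions` dictionary mapping frame indices to camera parameter dictionaries in the format above (the dictionary content below is a placeholder):

```python
import json
import zipfile

# Placeholder: in practice each entry holds the full parameter dictionary for that frame.
predictions = {1: {"pan_degrees": 14.86, "tilt_degrees": 78.84, "roll_degrees": -2.22}}

with zipfile.ZipFile("test.zip", "w") as archive:
    for frame_index, params in predictions.items():
        archive.writestr(f"camera_{frame_index:05d}.json", json.dumps(params, indent=4))
```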
|
259 |
+
|
260 |
+
|
261 |
+
### Evaluation
|
262 |
+
|
263 |
+
We only evaluate the quality of the provided camera parameters in the image domain, as it is the only groundtruth that we
|
264 |
+
have. The annotated points are either extremities or points on circles, and thus they do not always have a mapping to a
|
265 |
+
3D point, but we can always map them to lines or circles.
|
266 |
+
|
267 |
+
#### Dealing with ambiguities
|
268 |
+
|
269 |
+
The dataset contains some ambiguities when we consider each image independently. Without context, one can't know if a
|
270 |
+
camera behind a goal is on the left or the right side of the pitch.
|
271 |
+
|
272 |
+
For example, consider this image taken by a camera behind the goal:
|
273 |
+
|
274 |
+

|
275 |
+
|
276 |
+
It is impossible to say without any context if it was shot by the camera whose filming range is in blue or in yellow.
|
277 |
+
|
278 |
+

|
279 |
+
|
280 |
+
Therefore, we take this ambiguity into account in the evaluation and we consider both the accuracy for the groundtruth
|
281 |
+
labels, and the accuracy for the central projection of the labels through the pitch center. The higher accuracy will be
|
282 |
+
selected for your evaluation.
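
As an informal sketch of this central symmetry (assuming pitch coordinates centered on the pitch center and the class names listed above; the helper is illustrative, not the evaluation code):

```python
import numpy as np

def central_symmetry(point3d, label):
    """Rotate a pitch point by 180 degrees around the pitch center and swap
    the corresponding side words in the class label."""
    x, y, z = point3d
    swaps = {"left": "right", "right": "left", "top": "bottom", "bottom": "top"}
    mirrored_label = " ".join(swaps.get(word, word) for word in label.split(" "))
    return np.array([-x, -y, z]), mirrored_label

print(central_symmetry(np.array([-52.5, 0., 0.]), "Side line left"))
```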
|
283 |
+
|
284 |
+
|
285 |
+
We evaluate the best submission based on the accuracy at a specific threshold distance. This metric is explained
|
286 |
+
hereunder.
|
287 |
+
|
288 |
+
#### Accuracy @ threshold
|
289 |
+
|
290 |
+
The evaluation is based on the reprojection error which we define here as the L2 distance between one annotated point
|
291 |
+
and the line to which the point belongs. This metric does not account well for false positives and false negatives
|
292 |
+
(hallucinated/missing line projections). Thus we formulate our evaluation as a binary classification with a
|
293 |
+
distance threshold, with a twist: this time, we consider a pitch marking to be one entity, and for it to be correctly
|
294 |
+
detected, all its extremities (or all points annotated for circles) must have a reprojection error smaller than the
|
295 |
+
threshold.
|
296 |
+
|
297 |
+
As we allow lens distortion, the projection of the pitch line markings can be curved. This is why we sample the pitch
|
298 |
+
model every few centimeters, and we consider that the distance between a projected pitch marking and a groundtruth point is in
|
299 |
+
fact the Euclidean distance between the groundtruth point and the polyline given by the projection of the sampled points.
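
A minimal sketch of this point-to-polyline distance (not the exact implementation of evaluate_camera.py), assuming the projected marking is given as an (N, 2) array of sampled pixel coordinates:

```python
import numpy as np

def point_to_polyline_distance(point, polyline):
    """Minimum Euclidean distance from a 2D point to a sampled polyline of shape (N, 2)."""
    best = np.inf
    for a, b in zip(polyline[:-1], polyline[1:]):
        ab = b - a
        denom = np.dot(ab, ab)
        if denom == 0.:  # degenerate segment (duplicated sample)
            best = min(best, np.linalg.norm(point - a))
            continue
        # Project the point on the segment and clamp to its endpoints.
        t = np.clip(np.dot(point - a, ab) / denom, 0., 1.)
        best = min(best, np.linalg.norm(point - (a + t * ab)))
    return best
```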
|
300 |
+
|
301 |
+
* True positives: for classes that belong both to the prediction and the groundtruth, a predicted element is a True
|
302 |
+
Positive if all the L2 distances between its groundtruth points and the predicted polyline are lower than a
|
303 |
+
certain threshold.
|
304 |
+
|
305 |
+

|
306 |
+
|
307 |
+
* False positives: elements that were detected with a class that does not belong to the groundtruth classes, and
|
308 |
+
elements with valid classes which are at least **t** pixels away from one of the groundtruth points associated
|
309 |
+
with the element.
|
310 |
+
* False negatives: Line elements only present in the groundtruth are counted as False Negatives.
|
311 |
+
* True negatives: There are no True Negatives.
|
312 |
+
|
313 |
+
The Accuracy for a threshold of t pixels is given by: **Acc@t = TP/(TP+FN+FP)**. We evaluate the accuracy at 5 pixels.
|
314 |
+
We only use images with predicted camera parameters in this evaluation.
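
A sketch of how the counts could be accumulated under this per-marking rule, reusing the `point_to_polyline_distance` helper sketched above; the dictionary layout is an assumption, not the official evaluation code:

```python
import numpy as np

def accuracy_at_threshold(groundtruth, predictions, t=5.0):
    """groundtruth: class name -> annotated 2D points; predictions: class name ->
    projected polyline samples. Returns Acc@t with the per-marking rule."""
    TP = FP = FN = 0
    for cls, gt_points in groundtruth.items():
        if cls not in predictions:
            FN += 1  # marking missing from the prediction
            continue
        polyline = np.asarray(predictions[cls])
        if all(point_to_polyline_distance(np.asarray(p), polyline) < t for p in gt_points):
            TP += 1  # every annotated point of this marking is within t pixels
        else:
            FP += 1
    FP += sum(1 for cls in predictions if cls not in groundtruth)  # hallucinated markings
    total = TP + FN + FP
    return TP / total if total else 0.0
```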
|
315 |
+
|
316 |
+
#### Completeness rate
|
317 |
+
|
318 |
+
We also measure the completeness rate as the number of camera parameters provided divided by the number of images with
|
319 |
+
more than four semantic line annotations in the dataset.
|
320 |
+
|
321 |
+
#### Final score
|
322 |
+
|
323 |
+
The evaluation criterion for a camera calibration method is the following: **Completeness x Acc@5**
|
324 |
+
|
325 |
+
#### Per class information
|
326 |
+
|
327 |
+
The global accuracy described above has the advantage of treating all soccer pitch marking types equivalently even if the
|
328 |
+
groundtruth contains more annotated points for a specific class. For instance, a circle is annotated with 9 points on
|
329 |
+
average, whilst rectilinear elements usually have two annotated points. But this metric might be stricter, as we consider
|
330 |
+
all annotated points jointly instead of each of them independently. This is precisely why we propose this per-class metric, which
|
331 |
+
accounts for each annotated point separately. This metric is only for information purposes and will not be used in the
|
332 |
+
ranking of submissions.
|
333 |
+
|
334 |
+
The prediction is obtained by sampling each 3D real-world pitch element every few centimeters, which means that the
|
335 |
+
number of points in the prediction may be variable for the same camera parameters. This gives a very high number of
|
336 |
+
predicted points for a certain class, and thus we use a workaround to count false positives.
|
337 |
+
|
338 |
+
The confusion matrices are computed per class in the following way :
|
339 |
+
|
340 |
+
* True positives: for classes that belong both to the prediction and the groundtruth, a groundtruth point is counted in
|
341 |
+
the True Positives if the L2 distance from this groundtruth point to the predicted polyline is lower than a
|
342 |
+
certain threshold **t**.
|
343 |
+
|
344 |
+
* False positives: counts groundtruth points that have a distance to the corresponding predicted polyline that is higher
|
345 |
+
than the threshold value **t**. For predicted lines that do not belong to the groundtruth, we cannot count every
|
346 |
+
predicted point as a false positive because the number of points depends on the sampling factor, which has to be high
|
347 |
+
enough; counting them all could lower our metric a lot. We arbitrarily decided to count, for this kind of false positive, the number
|
348 |
+
of points that a human annotator would have annotated for this class, i.e. two points for rectilinear elements and 9
|
349 |
+
for circle elements.
|
350 |
+
* False negatives: All groundtruth points whose class is only present in the groundtruth are counted as False Negatives.
|
351 |
+
* True negatives: There are no True Negatives.
|
352 |
+
|
353 |
+
The Accuracy for a threshold of t pixels is given by: **Acc@t = TP/(TP+FN+FP)**. We evaluate the accuracy at 5, 10 and
|
354 |
+
20 pixels. We only use images with predicted camera parameters in this evaluation.
|
355 |
+
|
356 |
+
|
357 |
+
### Baseline
|
358 |
+
|
359 |
+
#### Method
|
360 |
+
|
361 |
+
For our camera calibration baseline, we proceed in two steps: first we find the pitch marking locations in the image,
|
362 |
+
and then given our pitch marking correspondences, we estimate camera parameters.
|
363 |
+
|
364 |
+
For this first step, we decided to locate the pitch markings with a neural network trained to perform semantic line
|
365 |
+
segmentation. We used the DeepLabv3 architecture. The target semantic segmentation masks were generated by joining
|
366 |
+
successively all the points annotated for each line in the image. We provide the dataloader that we used for the
|
367 |
+
training in the src folder.
|
368 |
+
|
369 |
+
The segmentation maps predicted by the neural
|
370 |
+
network are further processed in order to get the line extremities. First, for each class, we fit circles on the
|
371 |
+
segmentation mask. All pixels belonging to the same class are thus summarized by a set of points (i.e. circle centers).
|
372 |
+
Then we build polylines based on the circle centers: all the points that are close enough are considered to belong to
|
373 |
+
the same polyline. Finally, the extremities of each line class will be the extremities of the longest polyline for that
|
374 |
+
line class.
|
375 |
+
|
376 |
+
You can test the line detection with the following command:
|
377 |
+
|
378 |
+
```
|
379 |
+
python src/detect_extremities.py -s <path_to_soccernet_dataset> -p <path_to_store_predictions>
|
380 |
+
```
|
381 |
+
|
382 |
+
|
383 |
+
In the second step, we use the extremities of the lines (
|
384 |
+
not circles) detected in the image in order to estimate a homography from the soccer pitch model to the image. We
|
385 |
+
provide a class in **soccerpitch.py** that defines the pitch model according to the rules of the game. The homography is then
|
386 |
+
decomposed into camera parameters. All aspects concerning camera parameters are located in **camera.py**, including
|
387 |
+
homography decomposition into rotation and translation matrices, calibration matrix estimation from the
|
388 |
+
homography, and projection functions.
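
As a hedged sketch of this second step, assuming a homography `H` from the pitch plane to a 960x540 image has already been estimated (the import path depends on how you run the scripts):

```python
import json
import numpy as np
from camera import Camera  # sn_calibration/src/camera.py

H = np.eye(3)  # placeholder: use the homography estimated from the line correspondences

cam = Camera(iwidth=960, iheight=540)
if cam.from_homography(H):
    # Decomposition succeeded: export the camera parameters in the expected JSON format.
    with open("camera_00001.json", "w") as f:
        json.dump(cam.to_json_parameters(), f, indent=4)
```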
|
389 |
+
|
390 |
+
You can test the baseline with the following line:
|
391 |
+
|
392 |
+
```python src/baseline_cameras.py -s <path to soccernet dataset> -p <path to 1st task prediction>```
|
393 |
+
|
394 |
+
And to test the evaluation, you can run:
|
395 |
+
|
396 |
+
`python src/evaluate_camera.py -s <path to soccernet dataset> -p <path to predictions> -t <threshold value>`
|
397 |
+
|
398 |
+
|
399 |
+
#### Results
|
400 |
+
|
401 |
+
| Method   | Acc@5 | Completeness | Final score |
|
402 |
+
|----------|-------|--------------|-------------|
|
403 |
+
| Baseline | 11.7% | 68% | 7.96% |
|
404 |
+
|
405 |
+
#### Improvements
|
406 |
+
|
407 |
+
The baseline could be directly improved by:
|
408 |
+
|
409 |
+
* exploiting the masks rather than the extremities prediction of the first baseline
|
410 |
+
* using RANSAC to estimate the homography (a sketch follows this list)
|
411 |
+
* refining the camera parameters using line and ellipse correspondences.
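
For the RANSAC suggestion, a minimal sketch with OpenCV (the correspondence arrays are placeholders; in practice they would come from the first-step predictions):

```python
import cv2 as cv
import numpy as np

# Placeholder correspondences between pitch-plane points (meters) and image pixels.
pitch_pts = (np.random.rand(12, 2) * [105., 68.]).astype(np.float32)
image_pts = (np.random.rand(12, 2) * [960., 540.]).astype(np.float32)

# RANSAC discards outlier correspondences while estimating the homography.
H, inlier_mask = cv.findHomography(pitch_pts, image_pts, cv.RANSAC, ransacReprojThreshold=3.0)
```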
|
412 |
+
|
413 |
+
## Citation
|
414 |
+
For further information check out the paper and supplementary material:
|
415 |
+
https://arxiv.org/abs/2210.02365
|
416 |
+
|
417 |
+
Please cite our work if you use the SoccerNet dataset:
|
418 |
+
```bibtex
|
419 |
+
@inproceedings{Giancola_2022,
|
420 |
+
doi = {10.1145/3552437.3558545},
|
421 |
+
url = {https://doi.org/10.1145%2F3552437.3558545},
|
422 |
+
year = 2022,
|
423 |
+
month = {oct},
|
424 |
+
publisher = {{ACM}},
|
425 |
+
author = {Silvio Giancola and Anthony Cioppa and Adrien Deli{\`{e}}ge and Floriane Magera and Vladimir Somers and Le Kang and Xin Zhou and Olivier Barnich and Christophe De Vleeschouwer and Alexandre Alahi and Bernard Ghanem and Marc Van Droogenbroeck and Abdulrahman Darwish and Adrien Maglo and Albert Clap{\'{e}}s and Andreas Luyts and Andrei Boiarov and Artur Xarles and Astrid Orcesi and Avijit Shah and Baoyu Fan and Bharath Comandur and Chen Chen and Chen Zhang and Chen Zhao and Chengzhi Lin and Cheuk-Yiu Chan and Chun Chuen Hui and Dengjie Li and Fan Yang and Fan Liang and Fang Da and Feng Yan and Fufu Yu and Guanshuo Wang and H. Anthony Chan and He Zhu and Hongwei Kan and Jiaming Chu and Jianming Hu and Jianyang Gu and Jin Chen and Jo{\~{a}}o V. B. Soares and Jonas Theiner and Jorge De Corte and Jos{\'{e}} Henrique Brito and Jun Zhang and Junjie Li and Junwei Liang and Leqi Shen and Lin Ma and Lingchi Chen and Miguel Santos Marques and Mike Azatov and Nikita Kasatkin and Ning Wang and Qiong Jia and Quoc Cuong Pham and Ralph Ewerth and Ran Song and Rengang Li and Rikke Gade and Ruben Debien and Runze Zhang and Sangrok Lee and Sergio Escalera and Shan Jiang and Shigeyuki Odashima and Shimin Chen and Shoichi Masui and Shouhong Ding and Sin-wai Chan and Siyu Chen and Tallal El-Shabrawy and Tao He and Thomas B. Moeslund and Wan-Chi Siu and Wei Zhang and Wei Li and Xiangwei Wang and Xiao Tan and Xiaochuan Li and Xiaolin Wei and Xiaoqing Ye and Xing Liu and Xinying Wang and Yandong Guo and Yaqian Zhao and Yi Yu and Yingying Li and Yue He and Yujie Zhong and Zhenhua Guo and Zhiheng Li},
|
426 |
+
title = {{SoccerNet} 2022 Challenges Results},
|
427 |
+
booktitle = {Proceedings of the 5th International {ACM} Workshop on Multimedia Content Analysis in Sports}
|
428 |
+
}
|
429 |
+
```
|
430 |
+
|
431 |
+
|
432 |
+
|
433 |
+
## Our other Challenges
|
434 |
+
|
435 |
+
Check out our other challenges related to SoccerNet!
|
436 |
+
- [Action Spotting](https://github.com/SoccerNet/sn-spotting)
|
437 |
+
- [Replay Grounding](https://github.com/SoccerNet/sn-grounding)
|
438 |
+
- [Calibration](https://github.com/SoccerNet/sn-calibration)
|
439 |
+
- [Re-Identification](https://github.com/SoccerNet/sn-reid)
|
440 |
+
- [Tracking](https://github.com/SoccerNet/sn-tracking)
|
sn_calibration/requirements.txt
ADDED
@@ -0,0 +1,17 @@
1 |
+
# numpy~>1.21
|
2 |
+
# torch~=1.10.0
|
3 |
+
# torchvision~=0.11.1
|
4 |
+
# pillow~>9.0.0
|
5 |
+
# tqdm~=4.62.3
|
6 |
+
# SoccerNet>=0.1.23
|
7 |
+
# opencv-python~=4.5.5
|
8 |
+
# matplotlib~=3.5.1
|
9 |
+
|
10 |
+
numpy>=1.21
|
11 |
+
torch>=2.0.0
|
12 |
+
torchvision>=0.15.0
|
13 |
+
pillow>=9.0.0
|
14 |
+
tqdm~=4.62.3
|
15 |
+
SoccerNet>=0.1.23
|
16 |
+
opencv-python~=4.5.5
|
17 |
+
matplotlib~=3.5.1
|
sn_calibration/resources/mean.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8dcd1f96b2486853498b6e38a0eb344cda9c272cd0fe87adb84ec0a09e244a36
|
3 |
+
size 152
|
sn_calibration/resources/std.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:49dac1cfa5ab2f197a64332c4fd7dc08335dd46d31244acc0e4a03686241fe79
|
3 |
+
size 152
|
sn_calibration/src/baseline_cameras.py
ADDED
@@ -0,0 +1,251 @@
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import cv2 as cv
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from src.camera import Camera
|
10 |
+
from src.soccerpitch import SoccerPitch
|
11 |
+
|
12 |
+
|
13 |
+
def normalization_transform(points):
|
14 |
+
"""
|
15 |
+
Computes the similarity transform such that the list of points is centered around (0,0) and that its distance to the
|
16 |
+
center is sqrt(2).
|
17 |
+
:param points: point cloud that we wish to normalize
|
18 |
+
:return: the affine transformation matrix
|
19 |
+
"""
|
20 |
+
center = np.mean(points, axis=0)
|
21 |
+
|
22 |
+
d = 0.
|
23 |
+
nelems = 0
|
24 |
+
for p in points:
|
25 |
+
nelems += 1
|
26 |
+
x = p[0] - center[0]
|
27 |
+
y = p[1] - center[1]
|
28 |
+
di = np.sqrt(x ** 2 + y ** 2)
|
29 |
+
d += (di - d) / nelems
|
30 |
+
|
31 |
+
if d <= 0.:
|
32 |
+
s = 1.
|
33 |
+
else:
|
34 |
+
s = np.sqrt(2) / d
|
35 |
+
T = np.zeros((3, 3))
|
36 |
+
T[0, 0] = s
|
37 |
+
T[0, 2] = -s * center[0]
|
38 |
+
T[1, 1] = s
|
39 |
+
T[1, 2] = -s * center[1]
|
40 |
+
T[2, 2] = 1
|
41 |
+
return T
|
42 |
+
|
43 |
+
|
44 |
+
def estimate_homography_from_line_correspondences(lines, T1=np.eye(3), T2=np.eye(3)):
|
45 |
+
"""
|
46 |
+
Given lines correspondences, computes the homography that maps best the two set of lines.
|
47 |
+
:param lines: list of pair of 2D lines matches.
|
48 |
+
:param T1: Similarity transform to normalize the elements of the source reference system
|
49 |
+
:param T2: Similarity transform to normalize the elements of the target reference system
|
50 |
+
:return: boolean to indicate success or failure of the estimation, homography
|
51 |
+
"""
|
52 |
+
homography = np.eye(3)
|
53 |
+
A = np.zeros((len(lines) * 2, 9))
|
54 |
+
|
55 |
+
for i, line_pair in enumerate(lines):
|
56 |
+
src_line = np.transpose(np.linalg.inv(T1)) @ line_pair[0]
|
57 |
+
target_line = np.transpose(np.linalg.inv(T2)) @ line_pair[1]
|
58 |
+
u = src_line[0]
|
59 |
+
v = src_line[1]
|
60 |
+
w = src_line[2]
|
61 |
+
|
62 |
+
x = target_line[0]
|
63 |
+
y = target_line[1]
|
64 |
+
z = target_line[2]
|
65 |
+
|
66 |
+
A[2 * i, 0] = 0
|
67 |
+
A[2 * i, 1] = x * w
|
68 |
+
A[2 * i, 2] = -x * v
|
69 |
+
A[2 * i, 3] = 0
|
70 |
+
A[2 * i, 4] = y * w
|
71 |
+
A[2 * i, 5] = -v * y
|
72 |
+
A[2 * i, 6] = 0
|
73 |
+
A[2 * i, 7] = z * w
|
74 |
+
A[2 * i, 8] = -v * z
|
75 |
+
|
76 |
+
A[2 * i + 1, 0] = x * w
|
77 |
+
A[2 * i + 1, 1] = 0
|
78 |
+
A[2 * i + 1, 2] = -x * u
|
79 |
+
A[2 * i + 1, 3] = y * w
|
80 |
+
A[2 * i + 1, 4] = 0
|
81 |
+
A[2 * i + 1, 5] = -u * y
|
82 |
+
A[2 * i + 1, 6] = z * w
|
83 |
+
A[2 * i + 1, 7] = 0
|
84 |
+
A[2 * i + 1, 8] = -u * z
|
85 |
+
|
86 |
+
try:
|
87 |
+
u, s, vh = np.linalg.svd(A)
|
88 |
+
except np.linalg.LinAlgError:
|
89 |
+
return False, homography
|
90 |
+
v = np.eye(3)
|
91 |
+
has_positive_singular_value = False
|
92 |
+
for i in range(s.shape[0] - 1, -2, -1):
|
93 |
+
v = np.reshape(vh[i], (3, 3))
|
94 |
+
|
95 |
+
if s[i] > 0:
|
96 |
+
has_positive_singular_value = True
|
97 |
+
break
|
98 |
+
|
99 |
+
if not has_positive_singular_value:
|
100 |
+
return False, homography
|
101 |
+
|
102 |
+
homography = np.reshape(v, (3, 3))
|
103 |
+
homography = np.linalg.inv(T2) @ homography @ T1
|
104 |
+
homography /= homography[2, 2]
|
105 |
+
|
106 |
+
return True, homography
|
107 |
+
|
108 |
+
|
109 |
+
def draw_pitch_homography(image, homography):
|
110 |
+
"""
|
111 |
+
Draws points along the soccer pitch markings elements in the image based on the homography projection.
|
112 |
+
/!\ This function assumes that the resolution of the image is 540p.
|
113 |
+
:param image
|
114 |
+
:param homography: homography that captures the relation between the world pitch plane and the image
|
115 |
+
:return: modified image
|
116 |
+
"""
|
117 |
+
field = SoccerPitch()
|
118 |
+
polylines = field.sample_field_points()
|
119 |
+
for line in polylines.values():
|
120 |
+
|
121 |
+
for point in line:
|
122 |
+
if point[2] == 0.:
|
123 |
+
hp = np.array((point[0], point[1], 1.))
|
124 |
+
projected = homography @ hp
|
125 |
+
if projected[2] == 0.:
|
126 |
+
continue
|
127 |
+
projected /= projected[2]
|
128 |
+
if 0 < projected[0] < 960 and 0 < projected[1] < 540:
|
129 |
+
cv.circle(image, (int(projected[0]), int(projected[1])), 1, (255, 0, 0), 1)
|
130 |
+
|
131 |
+
return image
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
|
136 |
+
parser = argparse.ArgumentParser(description='Baseline for camera parameters extraction')
|
137 |
+
|
138 |
+
parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023-bis/", type=str,
|
139 |
+
help='Path to the SoccerNet-V3 dataset folder')
|
140 |
+
parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/",
|
141 |
+
required=False, type=str,
|
142 |
+
help="Path to the prediction folder")
|
143 |
+
parser.add_argument('--split', required=False, type=str, default="valid", help='Select the split of data')
|
144 |
+
parser.add_argument('--resolution_width', required=False, type=int, default=960,
|
145 |
+
help='width resolution of the images')
|
146 |
+
parser.add_argument('--resolution_height', required=False, type=int, default=540,
|
147 |
+
help='height resolution of the images')
|
148 |
+
args = parser.parse_args()
|
149 |
+
|
150 |
+
field = SoccerPitch()
|
151 |
+
|
152 |
+
dataset_dir = os.path.join(args.soccernet, args.split)
|
153 |
+
if not os.path.exists(dataset_dir):
|
154 |
+
print("Invalid dataset path !")
|
155 |
+
exit(-1)
|
156 |
+
|
157 |
+
with open(os.path.join(dataset_dir, "per_match_info.json"), 'r') as f:
|
158 |
+
match_info = json.load(f)
|
159 |
+
|
160 |
+
with tqdm(enumerate(match_info.keys()), total=len(match_info.keys()), ncols=160) as t:
|
161 |
+
for i, match in t:
|
162 |
+
frame_list = match_info[match].keys()
|
163 |
+
|
164 |
+
for frame in frame_list:
|
165 |
+
frame_index = frame.split(".")[0]
|
166 |
+
prediction_file = os.path.join(args.prediction, args.split, f"extremities_{frame_index}.json")
|
167 |
+
|
168 |
+
if not os.path.exists(prediction_file):
|
169 |
+
continue
|
170 |
+
|
171 |
+
with open(prediction_file, 'r') as f:
|
172 |
+
predictions = json.load(f)
|
173 |
+
|
174 |
+
camera_predictions = dict()
|
175 |
+
image_path = os.path.join(dataset_dir, frame)
|
176 |
+
# cv_image = cv.imread(image_path)
|
177 |
+
# cv_image = cv.resize(cv_image, (args.resolution_width, args.resolution_height))
|
178 |
+
|
179 |
+
line_matches = []
|
180 |
+
potential_3d_2d_matches = {}
|
181 |
+
src_pts = []
|
182 |
+
success = False
|
183 |
+
for k, v in predictions.items():
|
184 |
+
if k == 'Circle central' or "unknown" in k:
|
185 |
+
continue
|
186 |
+
P3D1 = field.line_extremities_keys[k][0]
|
187 |
+
P3D2 = field.line_extremities_keys[k][1]
|
188 |
+
p1 = np.array([v[0]['x'] * args.resolution_width, v[0]['y'] * args.resolution_height, 1.])
|
189 |
+
p2 = np.array([v[1]['x'] * args.resolution_width, v[1]['y'] * args.resolution_height, 1.])
|
190 |
+
src_pts.extend([p1, p2])
|
191 |
+
if P3D1 in potential_3d_2d_matches.keys():
|
192 |
+
potential_3d_2d_matches[P3D1].extend([p1, p2])
|
193 |
+
else:
|
194 |
+
potential_3d_2d_matches[P3D1] = [p1, p2]
|
195 |
+
if P3D2 in potential_3d_2d_matches.keys():
|
196 |
+
potential_3d_2d_matches[P3D2].extend([p1, p2])
|
197 |
+
else:
|
198 |
+
potential_3d_2d_matches[P3D2] = [p1, p2]
|
199 |
+
|
200 |
+
start = (int(p1[0]), int(p1[1]))
|
201 |
+
end = (int(p2[0]), int(p2[1]))
|
202 |
+
# cv.line(cv_image, start, end, (0, 0, 255), 1)
|
203 |
+
|
204 |
+
line = np.cross(p1, p2)
|
205 |
+
if np.isnan(np.sum(line)) or np.isinf(np.sum(line)):
|
206 |
+
continue
|
207 |
+
line_pitch = field.get_2d_homogeneous_line(k)
|
208 |
+
if line_pitch is not None:
|
209 |
+
line_matches.append((line_pitch, line))
|
210 |
+
|
211 |
+
if len(line_matches) >= 4:
|
212 |
+
target_pts = [field.point_dict[k][:2] for k in potential_3d_2d_matches.keys()]
|
213 |
+
T1 = normalization_transform(target_pts)
|
214 |
+
T2 = normalization_transform(src_pts)
|
215 |
+
success, homography = estimate_homography_from_line_correspondences(line_matches, T1, T2)
|
216 |
+
if success:
|
217 |
+
# cv_image = draw_pitch_homography(cv_image, homography)
|
218 |
+
|
219 |
+
cam = Camera(args.resolution_width, args.resolution_height)
|
220 |
+
success = cam.from_homography(homography)
|
221 |
+
if success:
|
222 |
+
point_matches = []
|
223 |
+
added_pts = set()
|
224 |
+
for k, potential_matches in potential_3d_2d_matches.items():
|
225 |
+
p3D = field.point_dict[k]
|
226 |
+
projected = cam.project_point(p3D)
|
227 |
+
|
228 |
+
if 0 < projected[0] < args.resolution_width and 0 < projected[
|
229 |
+
1] < args.resolution_height:
|
230 |
+
dist = np.zeros(len(potential_matches))
|
231 |
+
for i, potential_match in enumerate(potential_matches):
|
232 |
+
dist[i] = np.sqrt((projected[0] - potential_match[0]) ** 2 + (
|
233 |
+
projected[1] - potential_match[1]) ** 2)
|
234 |
+
selected = np.argmin(dist)
|
235 |
+
if dist[selected] < 100:
|
236 |
+
point_matches.append((p3D, potential_matches[selected][:2]))
|
237 |
+
|
238 |
+
if len(point_matches) > 3:
|
239 |
+
cam.refine_camera(point_matches)
|
240 |
+
# cam.draw_colorful_pitch(cv_image, SoccerField.palette)
|
241 |
+
# print(image_path)
|
242 |
+
# cv.imshow("colorful pitch", cv_image)
|
243 |
+
# cv.waitKey(0)
|
244 |
+
|
245 |
+
if success:
|
246 |
+
camera_predictions = cam.to_json_parameters()
|
247 |
+
|
248 |
+
task2_prediction_file = os.path.join(args.prediction, args.split, f"camera_{frame_index}.json")
|
249 |
+
if camera_predictions:
|
250 |
+
with open(task2_prediction_file, "w") as f:
|
251 |
+
json.dump(camera_predictions, f, indent=4)
|
sn_calibration/src/camera.py
ADDED
@@ -0,0 +1,402 @@
1 |
+
import cv2 as cv
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from soccerpitch import SoccerPitch
|
5 |
+
|
6 |
+
|
7 |
+
def pan_tilt_roll_to_orientation(pan, tilt, roll):
|
8 |
+
"""
|
9 |
+
Conversion from euler angles to orientation matrix.
|
10 |
+
:param pan:
|
11 |
+
:param tilt:
|
12 |
+
:param roll:
|
13 |
+
:return: orientation matrix
|
14 |
+
"""
|
15 |
+
Rpan = np.array([
|
16 |
+
[np.cos(pan), -np.sin(pan), 0],
|
17 |
+
[np.sin(pan), np.cos(pan), 0],
|
18 |
+
[0, 0, 1]])
|
19 |
+
Rroll = np.array([
|
20 |
+
[np.cos(roll), -np.sin(roll), 0],
|
21 |
+
[np.sin(roll), np.cos(roll), 0],
|
22 |
+
[0, 0, 1]])
|
23 |
+
Rtilt = np.array([
|
24 |
+
[1, 0, 0],
|
25 |
+
[0, np.cos(tilt), -np.sin(tilt)],
|
26 |
+
[0, np.sin(tilt), np.cos(tilt)]])
|
27 |
+
rotMat = np.dot(Rpan, np.dot(Rtilt, Rroll))
|
28 |
+
return rotMat
|
29 |
+
|
30 |
+
|
31 |
+
def rotation_matrix_to_pan_tilt_roll(rotation):
|
32 |
+
"""
|
33 |
+
Decomposes the rotation matrix into pan, tilt and roll angles. There are two solutions, but as we know that cameramen
|
34 |
+
try to minimize roll, we take the solution with the smallest roll.
|
35 |
+
:param rotation: rotation matrix
|
36 |
+
:return: pan, tilt and roll in radians
|
37 |
+
"""
|
38 |
+
orientation = np.transpose(rotation)
|
39 |
+
first_tilt = np.arccos(orientation[2, 2])
|
40 |
+
second_tilt = - first_tilt
|
41 |
+
|
42 |
+
sign_first_tilt = 1. if np.sin(first_tilt) > 0. else -1.
|
43 |
+
sign_second_tilt = 1. if np.sin(second_tilt) > 0. else -1.
|
44 |
+
|
45 |
+
first_pan = np.arctan2(sign_first_tilt * orientation[0, 2], sign_first_tilt * - orientation[1, 2])
|
46 |
+
second_pan = np.arctan2(sign_second_tilt * orientation[0, 2], sign_second_tilt * - orientation[1, 2])
|
47 |
+
first_roll = np.arctan2(sign_first_tilt * orientation[2, 0], sign_first_tilt * orientation[2, 1])
|
48 |
+
second_roll = np.arctan2(sign_second_tilt * orientation[2, 0], sign_second_tilt * orientation[2, 1])
|
49 |
+
|
50 |
+
# print(f"first solution {first_pan*180./np.pi}, {first_tilt*180./np.pi}, {first_roll*180./np.pi}")
|
51 |
+
# print(f"second solution {second_pan*180./np.pi}, {second_tilt*180./np.pi}, {second_roll*180./np.pi}")
|
52 |
+
if np.fabs(first_roll) < np.fabs(second_roll):
|
53 |
+
return first_pan, first_tilt, first_roll
|
54 |
+
return second_pan, second_tilt, second_roll
|
55 |
+
|
56 |
+
|
57 |
+
def unproject_image_point(homography, point2D):
|
58 |
+
"""
|
59 |
+
Given the homography from the world plane of the pitch and the image and a point localized on the pitch plane in the
|
60 |
+
image, returns the coordinates of the point in the 3D pitch plane.
|
61 |
+
/!\ Only works for correspondences on the pitch (Z = 0).
|
62 |
+
:param homography: the homography
|
63 |
+
:param point2D: the image point whose relative coordinates on the world plane of the pitch are to be found
|
64 |
+
:return: A 2D point on the world pitch plane in homogenous coordinates (X,Y,1) with X and Y being the world
|
65 |
+
coordinates of the point.
|
66 |
+
"""
|
67 |
+
hinv = np.linalg.inv(homography)
|
68 |
+
pitchpoint = hinv @ point2D
|
69 |
+
pitchpoint = pitchpoint / pitchpoint[2]
|
70 |
+
return pitchpoint
|
71 |
+
|
72 |
+
|
73 |
+
class Camera:
|
74 |
+
|
75 |
+
def __init__(self, iwidth=960, iheight=540):
|
76 |
+
self.position = np.zeros(3)
|
77 |
+
self.rotation = np.eye(3)
|
78 |
+
self.calibration = np.eye(3)
|
79 |
+
self.radial_distortion = np.zeros(6)
|
80 |
+
self.thin_prism_disto = np.zeros(4)
|
81 |
+
self.tangential_disto = np.zeros(2)
|
82 |
+
self.image_width = iwidth
|
83 |
+
self.image_height = iheight
|
84 |
+
self.xfocal_length = 1
|
85 |
+
self.yfocal_length = 1
|
86 |
+
self.principal_point = (self.image_width / 2, self.image_height / 2)
|
87 |
+
|
88 |
+
def solve_pnp(self, point_matches):
|
89 |
+
"""
|
90 |
+
With a known calibration matrix, this method can be used in order to retrieve rotation and translation camera
|
91 |
+
parameters.
|
92 |
+
:param point_matches: A list of pairs of 3D-2D point matches .
|
93 |
+
"""
|
94 |
+
target_pts = np.array([pt[0] for pt in point_matches])
|
95 |
+
src_pts = np.array([pt[1] for pt in point_matches])
|
96 |
+
_, rvec, t, inliers = cv.solvePnPRansac(target_pts, src_pts, self.calibration, None)
|
97 |
+
self.rotation, _ = cv.Rodrigues(rvec)
|
98 |
+
self.position = - np.transpose(self.rotation) @ t.flatten()
|
99 |
+
|
100 |
+
def refine_camera(self, pointMatches):
|
101 |
+
"""
|
102 |
+
Once that there is a minimal set of initial camera parameters (calibration, rotation and position roughly known),
|
103 |
+
this method can be used to refine the solution using a non-linear optimization procedure.
|
104 |
+
:param pointMatches: A list of pairs of 3D-2D point matches .
|
105 |
+
|
106 |
+
"""
|
107 |
+
rvec, _ = cv.Rodrigues(self.rotation)
|
108 |
+
target_pts = np.array([pt[0] for pt in pointMatches])
|
109 |
+
src_pts = np.array([pt[1] for pt in pointMatches])
|
110 |
+
|
111 |
+
rvec, t = cv.solvePnPRefineLM(target_pts, src_pts, self.calibration, None, rvec, -self.rotation @ self.position,
|
112 |
+
(cv.TERM_CRITERIA_MAX_ITER + cv.TERM_CRITERIA_EPS, 20000, 0.00001))
|
113 |
+
self.rotation, _ = cv.Rodrigues(rvec)
|
114 |
+
self.position = - np.transpose(self.rotation) @ t
|
115 |
+
|
116 |
+
def from_homography(self, homography):
|
117 |
+
"""
|
118 |
+
This method initializes the essential camera parameters from the homography between the world plane of the pitch
|
119 |
+
and the image. It is based on the extraction of the calibration matrix from the homography (Algorithm 8.2 of
|
120 |
+
Multiple View Geometry in computer vision, p225), then using the relation between the camera parameters and the
|
121 |
+
same homography, we extract rough rotation and position estimates (Example 8.1 of Multiple View Geometry in
|
122 |
+
computer vision, p196).
|
123 |
+
:param homography: The homography that captures the transformation between the 3D flat model of the soccer pitch
|
124 |
+
and its image.
|
125 |
+
"""
|
126 |
+
success, _ = self.estimate_calibration_matrix_from_plane_homography(homography)
|
127 |
+
if not success:
|
128 |
+
return False
|
129 |
+
|
130 |
+
hprim = np.linalg.inv(self.calibration) @ homography
|
131 |
+
lambda1 = 1 / np.linalg.norm(hprim[:, 0])
|
132 |
+
lambda2 = 1 / np.linalg.norm(hprim[:, 1])
|
133 |
+
lambda3 = np.sqrt(lambda1 * lambda2)
|
134 |
+
|
135 |
+
r0 = hprim[:, 0] * lambda1
|
136 |
+
r1 = hprim[:, 1] * lambda2
|
137 |
+
r2 = np.cross(r0, r1)
|
138 |
+
|
139 |
+
R = np.column_stack((r0, r1, r2))
|
140 |
+
u, s, vh = np.linalg.svd(R)
|
141 |
+
R = u @ vh
|
142 |
+
if np.linalg.det(R) < 0:
|
143 |
+
u[:, 2] *= -1
|
144 |
+
R = u @ vh
|
145 |
+
self.rotation = R
|
146 |
+
t = hprim[:, 2] * lambda3
|
147 |
+
self.position = - np.transpose(R) @ t
|
148 |
+
return True
|
149 |
+
|
150 |
+
def to_json_parameters(self):
|
151 |
+
"""
|
152 |
+
Saves camera to a JSON serializable dictionary.
|
153 |
+
:return: The dictionary
|
154 |
+
"""
|
155 |
+
pan, tilt, roll = rotation_matrix_to_pan_tilt_roll(self.rotation)
|
156 |
+
camera_dict = {
|
157 |
+
"pan_degrees": pan * 180. / np.pi,
|
158 |
+
"tilt_degrees": tilt * 180. / np.pi,
|
159 |
+
"roll_degrees": roll * 180. / np.pi,
|
160 |
+
"position_meters": self.position.tolist(),
|
161 |
+
"x_focal_length": self.xfocal_length,
|
162 |
+
"y_focal_length": self.yfocal_length,
|
163 |
+
"principal_point": [self.principal_point[0], self.principal_point[1]],
|
164 |
+
"radial_distortion": self.radial_distortion.tolist(),
|
165 |
+
"tangential_distortion": self.tangential_disto.tolist(),
|
166 |
+
"thin_prism_distortion": self.thin_prism_disto.tolist()
|
167 |
+
|
168 |
+
}
|
169 |
+
return camera_dict
|
170 |
+
|
171 |
+
def from_json_parameters(self, calib_json_object):
|
172 |
+
"""
|
173 |
+
Loads camera parameters from dictionary.
|
174 |
+
:param calib_json_object: the dictionary containing camera parameters.
|
175 |
+
"""
|
176 |
+
self.principal_point = calib_json_object["principal_point"]
|
177 |
+
self.image_width = 2 * self.principal_point[0]
|
178 |
+
self.image_height = 2 * self.principal_point[1]
|
179 |
+
self.xfocal_length = calib_json_object["x_focal_length"]
|
180 |
+
self.yfocal_length = calib_json_object["y_focal_length"]
|
181 |
+
|
182 |
+
self.calibration = np.array([
|
183 |
+
[self.xfocal_length, 0, self.principal_point[0]],
|
184 |
+
[0, self.yfocal_length, self.principal_point[1]],
|
185 |
+
[0, 0, 1]
|
186 |
+
], dtype='float')
|
187 |
+
|
188 |
+
pan = calib_json_object['pan_degrees'] * np.pi / 180.
|
189 |
+
tilt = calib_json_object['tilt_degrees'] * np.pi / 180.
|
190 |
+
roll = calib_json_object['roll_degrees'] * np.pi / 180.
|
191 |
+
|
192 |
+
self.rotation = np.array([
|
193 |
+
[-np.sin(pan) * np.sin(roll) * np.cos(tilt) + np.cos(pan) * np.cos(roll),
|
194 |
+
np.sin(pan) * np.cos(roll) + np.sin(roll) * np.cos(pan) * np.cos(tilt), np.sin(roll) * np.sin(tilt)],
|
195 |
+
[-np.sin(pan) * np.cos(roll) * np.cos(tilt) - np.sin(roll) * np.cos(pan),
|
196 |
+
-np.sin(pan) * np.sin(roll) + np.cos(pan) * np.cos(roll) * np.cos(tilt), np.sin(tilt) * np.cos(roll)],
|
197 |
+
[np.sin(pan) * np.sin(tilt), -np.sin(tilt) * np.cos(pan), np.cos(tilt)]
|
198 |
+
], dtype='float')
|
199 |
+
|
200 |
+
self.rotation = np.transpose(pan_tilt_roll_to_orientation(pan, tilt, roll))
|
201 |
+
|
202 |
+
self.position = np.array(calib_json_object['position_meters'], dtype='float')
|
203 |
+
|
204 |
+
self.radial_distortion = np.array(calib_json_object['radial_distortion'], dtype='float')
|
205 |
+
self.tangential_disto = np.array(calib_json_object['tangential_distortion'], dtype='float')
|
206 |
+
self.thin_prism_disto = np.array(calib_json_object['thin_prism_distortion'], dtype='float')
|
207 |
+
|
208 |
+
def distort(self, point):
|
209 |
+
"""
|
210 |
+
Given a point in the normalized image plane, apply distortion
|
211 |
+
:param point: 2D point on the normalized image plane
|
212 |
+
:return: 2D distorted point
|
213 |
+
"""
|
214 |
+
numerator = 1
|
215 |
+
denominator = 1
|
216 |
+
radius = np.sqrt(point[0] * point[0] + point[1] * point[1])
|
217 |
+
|
218 |
+
for i in range(3):
|
219 |
+
k = self.radial_distortion[i]
|
220 |
+
numerator += k * radius ** (2 * (i + 1))
|
221 |
+
k2n = self.radial_distortion[i + 3]
|
222 |
+
denominator += k2n * radius ** (2 * (i + 1))
|
223 |
+
|
224 |
+
radial_distortion_factor = numerator / denominator
|
225 |
+
xpp = point[0] * radial_distortion_factor + \
|
226 |
+
2 * self.tangential_disto[0] * point[0] * point[1] + self.tangential_disto[1] * (
|
227 |
+
radius ** 2 + 2 * point[0] ** 2) + \
|
228 |
+
self.thin_prism_disto[0] * radius ** 2 + self.thin_prism_disto[1] * radius ** 4
|
229 |
+
ypp = point[1] * radial_distortion_factor + \
|
230 |
+
2 * self.tangential_disto[1] * point[0] * point[1] + self.tangential_disto[0] * (
|
231 |
+
radius ** 2 + 2 * point[1] ** 2) + \
|
232 |
+
self.thin_prism_disto[2] * radius ** 2 + self.thin_prism_disto[3] * radius ** 4
|
233 |
+
return np.array([xpp, ypp], dtype=np.float32)
|
234 |
+
|
235 |
+
def project_point(self, point3D, distort=True):
|
236 |
+
"""
|
237 |
+
Uses current camera parameters to predict where a 3D point is seen by the camera.
|
238 |
+
:param point3D: The 3D point in world coordinates.
|
239 |
+
:param distort: optional parameter to allow projection without distortion.
|
240 |
+
:return: The 2D coordinates of the imaged point
|
241 |
+
"""
|
242 |
+
point = point3D - self.position
|
243 |
+
rotated_point = self.rotation @ np.transpose(point)
|
244 |
+
if rotated_point[2] <= 1e-3 :
|
245 |
+
return np.zeros(3)
|
246 |
+
rotated_point = rotated_point / rotated_point[2]
|
247 |
+
if distort:
|
248 |
+
distorted_point = self.distort(rotated_point)
|
249 |
+
else:
|
250 |
+
distorted_point = rotated_point
|
251 |
+
x = distorted_point[0] * self.xfocal_length + self.principal_point[0]
|
252 |
+
y = distorted_point[1] * self.yfocal_length + self.principal_point[1]
|
253 |
+
return np.array([x, y, 1])
|
254 |
+
|
255 |
+
def scale_resolution(self, factor):
|
256 |
+
"""
|
257 |
+
Adapts the internal parameters for image resolution changes
|
258 |
+
:param factor: scaling factor
|
259 |
+
"""
|
260 |
+
self.xfocal_length = self.xfocal_length * factor
|
261 |
+
self.yfocal_length = self.yfocal_length * factor
|
262 |
+
self.image_width = self.image_width * factor
|
263 |
+
self.image_height = self.image_height * factor
|
264 |
+
|
265 |
+
self.principal_point = (self.image_width / 2, self.image_height / 2)
|
266 |
+
|
267 |
+
self.calibration = np.array([
|
268 |
+
[self.xfocal_length, 0, self.principal_point[0]],
|
269 |
+
[0, self.yfocal_length, self.principal_point[1]],
|
270 |
+
[0, 0, 1]
|
271 |
+
], dtype='float')
|
272 |
+
|
273 |
+
def draw_corners(self, image, color=(0, 255, 0)):
|
274 |
+
"""
|
275 |
+
Draw the corners of a standard soccer pitch in the image.
|
276 |
+
:param image: cv image
|
277 |
+
:param color
|
278 |
+
:return: the image mat modified.
|
279 |
+
"""
|
280 |
+
field = SoccerPitch()
|
281 |
+
for pt3D in field.point_dict.values():
|
282 |
+
projected = self.project_point(pt3D)
|
283 |
+
if projected[2] == 0.:
|
284 |
+
continue
|
285 |
+
projected /= projected[2]
|
286 |
+
if 0 < projected[0] < self.image_width and 0 < projected[1] < self.image_height:
|
287 |
+
cv.circle(image, (int(projected[0]), int(projected[1])), 3, color, 2)
|
288 |
+
return image
|
289 |
+
|
290 |
+
def draw_pitch(self, image, color=(0, 255, 0)):
|
291 |
+
"""
|
292 |
+
Draws all the lines of the pitch on the image.
|
293 |
+
:param image
|
294 |
+
:param color
|
295 |
+
:return: modified image
|
296 |
+
"""
|
297 |
+
field = SoccerPitch()
|
298 |
+
|
299 |
+
polylines = field.sample_field_points()
|
300 |
+
for line in polylines.values():
|
301 |
+
prev_point = self.project_point(line[0])
|
302 |
+
for point in line[1:]:
|
303 |
+
projected = self.project_point(point)
|
304 |
+
if projected[2] == 0.:
|
305 |
+
continue
|
306 |
+
projected /= projected[2]
|
307 |
+
if 0 < projected[0] < self.image_width and 0 < projected[1] < self.image_height:
|
308 |
+
cv.line(image, (int(prev_point[0]), int(prev_point[1])), (int(projected[0]), int(projected[1])),
|
309 |
+
color, 1)
|
310 |
+
prev_point = projected
|
311 |
+
return image
|
312 |
+
|
313 |
+
def draw_colorful_pitch(self, image, palette):
|
314 |
+
"""
|
315 |
+
Draws all the lines of the pitch on the image, each line color is specified by the palette argument.
|
316 |
+
|
317 |
+
:param image:
|
318 |
+
:param palette: dictionary associating line classes names with their BGR color.
|
319 |
+
:return: modified image
|
320 |
+
"""
|
321 |
+
field = SoccerPitch()
|
322 |
+
|
323 |
+
polylines = field.sample_field_points()
|
324 |
+
for key, line in polylines.items():
|
325 |
+
if key not in palette.keys():
|
326 |
+
print(f"Can't draw {key}")
|
327 |
+
continue
|
328 |
+
prev_point = self.project_point(line[0])
|
329 |
+
for point in line[1:]:
|
330 |
+
projected = self.project_point(point)
|
331 |
+
if projected[2] == 0.:
|
332 |
+
continue
|
333 |
+
projected /= projected[2]
|
334 |
+
if 0 < projected[0] < self.image_width and 0 < projected[1] < self.image_height:
|
335 |
+
# BGR color
|
336 |
+
cv.line(image, (int(prev_point[0]), int(prev_point[1])), (int(projected[0]), int(projected[1])),
|
337 |
+
palette[key][::-1], 1)
|
338 |
+
prev_point = projected
|
339 |
+
return image
|
340 |
+
|
341 |
+
def estimate_calibration_matrix_from_plane_homography(self, homography):
|
342 |
+
"""
|
343 |
+
This method initializes the calibration matrix from the homography between the world plane of the pitch
|
344 |
+
and the image. It is based on the extraction of the calibration matrix from the homography (Algorithm 8.2 of
|
345 |
+
Multiple View Geometry in computer vision, p225). The extraction is sensitive to noise, which is why we keep the
|
346 |
+
principal point in the middle of the image rather than using the one extracted by this method.
|
347 |
+
:param homography: homography between the world plane of the pitch and the image
|
348 |
+
"""
|
349 |
+
H = np.reshape(homography, (9,))
|
350 |
+
A = np.zeros((5, 6))
|
351 |
+
A[0, 1] = 1.
|
352 |
+
A[1, 0] = 1.
|
353 |
+
A[1, 2] = -1.
|
354 |
+
A[2, 3] = self.principal_point[1] / self.principal_point[0]
|
355 |
+
A[2, 4] = -1.0
|
356 |
+
A[3, 0] = H[0] * H[1]
|
357 |
+
A[3, 1] = H[0] * H[4] + H[1] * H[3]
|
358 |
+
A[3, 2] = H[3] * H[4]
|
359 |
+
A[3, 3] = H[0] * H[7] + H[1] * H[6]
|
360 |
+
A[3, 4] = H[3] * H[7] + H[4] * H[6]
|
361 |
+
A[3, 5] = H[6] * H[7]
|
362 |
+
A[4, 0] = H[0] * H[0] - H[1] * H[1]
|
363 |
+
A[4, 1] = 2 * H[0] * H[3] - 2 * H[1] * H[4]
|
364 |
+
A[4, 2] = H[3] * H[3] - H[4] * H[4]
|
365 |
+
A[4, 3] = 2 * H[0] * H[6] - 2 * H[1] * H[7]
|
366 |
+
A[4, 4] = 2 * H[3] * H[6] - 2 * H[4] * H[7]
|
367 |
+
A[4, 5] = H[6] * H[6] - H[7] * H[7]
|
368 |
+
|
369 |
+
u, s, vh = np.linalg.svd(A)
|
370 |
+
w = vh[-1]
|
371 |
+
W = np.zeros((3, 3))
|
372 |
+
W[0, 0] = w[0] / w[5]
|
373 |
+
W[0, 1] = w[1] / w[5]
|
374 |
+
W[0, 2] = w[3] / w[5]
|
375 |
+
W[1, 0] = w[1] / w[5]
|
376 |
+
W[1, 1] = w[2] / w[5]
|
377 |
+
W[1, 2] = w[4] / w[5]
|
378 |
+
W[2, 0] = w[3] / w[5]
|
379 |
+
W[2, 1] = w[4] / w[5]
|
380 |
+
W[2, 2] = w[5] / w[5]
|
381 |
+
|
382 |
+
try:
|
383 |
+
Ktinv = np.linalg.cholesky(W)
|
384 |
+
except np.linalg.LinAlgError:
|
385 |
+
K = np.eye(3)
|
386 |
+
return False, K
|
387 |
+
|
388 |
+
K = np.linalg.inv(np.transpose(Ktinv))
|
389 |
+
K /= K[2, 2]
|
390 |
+
|
391 |
+
self.xfocal_length = K[0, 0]
|
392 |
+
self.yfocal_length = K[1, 1]
|
393 |
+
# the principal point estimated by this method is very noisy, better keep it in the center of the image
|
394 |
+
self.principal_point = (self.image_width / 2, self.image_height / 2)
|
395 |
+
# self.principal_point = (K[0,2], K[1,2])
|
396 |
+
self.calibration = np.array([
|
397 |
+
[self.xfocal_length, 0, self.principal_point[0]],
|
398 |
+
[0, self.yfocal_length, self.principal_point[1]],
|
399 |
+
[0, 0, 1]
|
400 |
+
], dtype='float')
|
401 |
+
return True, K
|
402 |
+
|
sn_calibration/src/dataloader.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
"""
|
2 |
+
DataLoader used to train the segmentation network used for the prediction of extremities.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
from argparse import ArgumentParser
|
9 |
+
|
10 |
+
import cv2 as cv
|
11 |
+
import numpy as np
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from src.soccerpitch import SoccerPitch
|
16 |
+
|
17 |
+
|
18 |
+
class SoccerNetDataset(Dataset):
|
19 |
+
def __init__(self,
|
20 |
+
datasetpath,
|
21 |
+
split="test",
|
22 |
+
width=640,
|
23 |
+
height=360,
|
24 |
+
mean="../resources/mean.npy",
|
25 |
+
std="../resources/std.npy"):
|
26 |
+
self.mean = np.load(mean)
|
27 |
+
self.std = np.load(std)
|
28 |
+
self.width = width
|
29 |
+
self.height = height
|
30 |
+
|
31 |
+
dataset_dir = os.path.join(datasetpath, split)
|
32 |
+
if not os.path.exists(dataset_dir):
|
33 |
+
print("Invalid dataset path !")
|
34 |
+
exit(-1)
|
35 |
+
|
36 |
+
frames = [f for f in os.listdir(dataset_dir) if ".jpg" in f]
|
37 |
+
|
38 |
+
self.data = []
|
39 |
+
self.n_samples = 0
|
40 |
+
for frame in frames:
|
41 |
+
|
42 |
+
frame_index = frame.split(".")[0]
|
43 |
+
annotation_file = os.path.join(dataset_dir, f"{frame_index}.json")
|
44 |
+
if not os.path.exists(annotation_file):
|
45 |
+
continue
|
46 |
+
with open(annotation_file, "r") as f:
|
47 |
+
groundtruth_lines = json.load(f)
|
48 |
+
img_path = os.path.join(dataset_dir, frame)
|
49 |
+
if groundtruth_lines:
|
50 |
+
self.data.append({
|
51 |
+
"image_path": img_path,
|
52 |
+
"annotations": groundtruth_lines,
|
53 |
+
})
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.data)
|
57 |
+
|
58 |
+
def __getitem__(self, index):
|
59 |
+
item = self.data[index]
|
60 |
+
|
61 |
+
img = cv.imread(item["image_path"])
|
62 |
+
img = cv.resize(img, (self.width, self.height), interpolation=cv.INTER_LINEAR)
|
63 |
+
|
64 |
+
mask = np.zeros(img.shape[:-1], dtype=np.uint8)
|
65 |
+
img = np.asarray(img, np.float32) / 255.
|
66 |
+
img -= self.mean
|
67 |
+
img /= self.std
|
68 |
+
img = img.transpose((2, 0, 1))
|
69 |
+
for class_number, class_ in enumerate(SoccerPitch.lines_classes):
|
70 |
+
if class_ in item["annotations"].keys():
|
71 |
+
key = class_
|
72 |
+
line = item["annotations"][key]
|
73 |
+
prev_point = line[0]
|
74 |
+
for i in range(1, len(line)):
|
75 |
+
next_point = line[i]
|
76 |
+
cv.line(mask,
|
77 |
+
(int(prev_point["x"] * mask.shape[1]), int(prev_point["y"] * mask.shape[0])),
|
78 |
+
(int(next_point["x"] * mask.shape[1]), int(next_point["y"] * mask.shape[0])),
|
79 |
+
class_number + 1,
|
80 |
+
2)
|
81 |
+
prev_point = next_point
|
82 |
+
return img, mask
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
|
87 |
+
# Load the arguments
|
88 |
+
parser = ArgumentParser(description='dataloader')
|
89 |
+
|
90 |
+
parser.add_argument('--SoccerNet_path', default="./annotations/", type=str,
|
91 |
+
help='Path to the SoccerNet-V3 dataset folder')
|
92 |
+
parser.add_argument('--tiny', required=False, type=int, default=None, help='Select a subset of x games')
|
93 |
+
parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
|
94 |
+
parser.add_argument('--num_workers', required=False, type=int, default=4,
|
95 |
+
help='number of workers for the dataloader')
|
96 |
+
parser.add_argument('--resolution_width', required=False, type=int, default=1920,
|
97 |
+
help='width resolution of the images')
|
98 |
+
parser.add_argument('--resolution_height', required=False, type=int, default=1080,
|
99 |
+
help='height resolution of the images')
|
100 |
+
parser.add_argument('--preload_images', action='store_true',
|
101 |
+
help="Preload the images when constructing the dataset")
|
102 |
+
parser.add_argument('--zipped_images', action='store_true', help="Read images from zipped folder")
|
103 |
+
|
104 |
+
args = parser.parse_args()
|
105 |
+
|
106 |
+
start_time = time.time()
|
107 |
+
soccernet = SoccerNetDataset(args.SoccerNet_path, split=args.split)
|
108 |
+
with tqdm(enumerate(soccernet), total=len(soccernet), ncols=160) as t:
|
109 |
+
for i, data in t:
|
110 |
+
img = soccernet[i][0].astype(np.uint8).transpose((1, 2, 0))
|
111 |
+
print(img.shape)
|
112 |
+
print(img.dtype)
|
113 |
+
cv.imshow("Normalized image", img)
|
114 |
+
cv.waitKey(0)
|
115 |
+
cv.destroyAllWindows()
|
116 |
+
print(data[1].shape)
|
117 |
+
cv.imshow("Mask", soccernet[i][1].astype(np.uint8))
|
118 |
+
cv.waitKey(0)
|
119 |
+
cv.destroyAllWindows()
|
120 |
+
continue
|
121 |
+
end_time = time.time()
|
122 |
+
print(end_time - start_time)
|
sn_calibration/src/detect_extremities.py
ADDED
@@ -0,0 +1,314 @@
import argparse
import copy
import json
import os.path
import random
from collections import deque
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50
from tqdm import tqdm

from soccerpitch import SoccerPitch


def generate_class_synthesis(semantic_mask, radius):
    """
    This function selects for each class present in the semantic mask, a set of circles that cover most of the semantic
    class blobs.
    :param semantic_mask: a image containing the segmentation predictions
    :param radius: circle radius
    :return: a dictionary which associates with each class detected a list of points ( the circles centers)
    """
    buckets = dict()
    kernel = np.ones((5, 5), np.uint8)
    semantic_mask = cv.erode(semantic_mask, kernel, iterations=1)
    for k, class_name in enumerate(SoccerPitch.lines_classes):
        mask = semantic_mask == k + 1
        if mask.sum() > 0:
            disk_list = synthesize_mask(mask, radius)
            if len(disk_list):
                buckets[class_name] = disk_list

    return buckets


def join_points(point_list, maxdist):
    """
    Given a list of points that were extracted from the blobs belonging to a same semantic class, this function creates
    polylines by linking close points together if their distance is below the maxdist threshold.
    :param point_list: List of points of the same line class
    :param maxdist: minimal distance between two polylines.
    :return: a list of polylines
    """
    polylines = []

    if not len(point_list):
        return polylines
    head = point_list[0]
    tail = point_list[0]
    polyline = deque()
    polyline.append(point_list[0])
    remaining_points = copy.deepcopy(point_list[1:])

    while len(remaining_points) > 0:
        min_dist_tail = 1000
        min_dist_head = 1000
        best_head = -1
        best_tail = -1
        for j, point in enumerate(remaining_points):
            dist_tail = np.sqrt(np.sum(np.square(point - tail)))
            dist_head = np.sqrt(np.sum(np.square(point - head)))
            if dist_tail < min_dist_tail:
                min_dist_tail = dist_tail
                best_tail = j
            if dist_head < min_dist_head:
                min_dist_head = dist_head
                best_head = j

        if min_dist_head <= min_dist_tail and min_dist_head < maxdist:
            polyline.appendleft(remaining_points[best_head])
            head = polyline[0]
            remaining_points.pop(best_head)
        elif min_dist_tail < min_dist_head and min_dist_tail < maxdist:
            polyline.append(remaining_points[best_tail])
            tail = polyline[-1]
            remaining_points.pop(best_tail)
        else:
            polylines.append(list(polyline.copy()))
            head = remaining_points[0]
            tail = remaining_points[0]
            polyline = deque()
            polyline.append(head)
            remaining_points.pop(0)
    polylines.append(list(polyline))
    return polylines


def get_line_extremities(buckets, maxdist, width, height):
    """
    Given the dictionary {lines_class: points}, finds plausible extremities of each line, i.e the extremities
    of the longest polyline that can be built on the class blobs, and normalize its coordinates
    by the image size.
    :param buckets: The dictionary associating line classes to the set of circle centers that covers best the class
    prediction blobs in the segmentation mask
    :param maxdist: the maximal distance between two circle centers belonging to the same blob (heuristic)
    :param width: image width
    :param height: image height
    :return: a dictionary associating to each class its extremities
    """
    extremities = dict()
    for class_name, disks_list in buckets.items():
        polyline_list = join_points(disks_list, maxdist)
        max_len = 0
        longest_polyline = []
        for polyline in polyline_list:
            if len(polyline) > max_len:
                max_len = len(polyline)
                longest_polyline = polyline
        extremities[class_name] = [
            {'x': longest_polyline[0][1] / width, 'y': longest_polyline[0][0] / height},
            {'x': longest_polyline[-1][1] / width, 'y': longest_polyline[-1][0] / height}
        ]
    return extremities


def get_support_center(mask, start, disk_radius, min_support=0.1):
    """
    Returns the barycenter of the True pixels under the area of the mask delimited by the circle of center start and
    radius of disk_radius pixels.
    :param mask: Boolean mask
    :param start: A point located on a true pixel of the mask
    :param disk_radius: the radius of the circles
    :param min_support: proportion of the area under the circle area that should be True in order to get enough support
    :return: A boolean indicating if there is enough support in the circle area, the barycenter of the True pixels under
    the circle
    """
    x = int(start[0])
    y = int(start[1])
    support_pixels = 1
    result = [x, y]
    xstart = x - disk_radius
    if xstart < 0:
        xstart = 0
    xend = x + disk_radius
    if xend > mask.shape[0]:
        xend = mask.shape[0] - 1

    ystart = y - disk_radius
    if ystart < 0:
        ystart = 0
    yend = y + disk_radius
    if yend > mask.shape[1]:
        yend = mask.shape[1] - 1

    for i in range(xstart, xend + 1):
        for j in range(ystart, yend + 1):
            dist = np.sqrt(np.square(x - i) + np.square(y - j))
            if dist < disk_radius and mask[i, j] > 0:
                support_pixels += 1
                result[0] += i
                result[1] += j
    support = True
    if support_pixels < min_support * np.square(disk_radius) * np.pi:
        support = False

    result = np.array(result)
    result = np.true_divide(result, support_pixels)

    return support, result


def synthesize_mask(semantic_mask, disk_radius):
    """
    Fits circles on the True pixels of the mask and returns those which have enough support : meaning that the
    proportion of the area of the circle covering True pixels is higher that a certain threshold in order to avoid
    fitting circles on alone pixels.
    :param semantic_mask: boolean mask
    :param disk_radius: radius of the circles
    :return: a list of disk centers, that have enough support
    """
    mask = semantic_mask.copy().astype(np.uint8)
    points = np.transpose(np.nonzero(mask))
    disks = []
    while len(points):

        start = random.choice(points)
        dist = 10.
        success = True
        while dist > 1.:
            enough_support, center = get_support_center(mask, start, disk_radius)
            if not enough_support:
                bad_point = np.round(center).astype(np.int32)
                cv.circle(mask, (bad_point[1], bad_point[0]), disk_radius, (0), -1)
                success = False
            dist = np.sqrt(np.sum(np.square(center - start)))
            start = center
        if success:
            disks.append(np.round(start).astype(np.int32))
            cv.circle(mask, (disks[-1][1], disks[-1][0]), disk_radius, 0, -1)
        points = np.transpose(np.nonzero(mask))

    return disks


class SegmentationNetwork:
    def __init__(self, model_file, mean_file, std_file, num_classes=29, width=640, height=360):
        file_path = Path(model_file).resolve()
        model = nn.DataParallel(deeplabv3_resnet50(pretrained=False, num_classes=num_classes))
        self.init_weight(model, nn.init.kaiming_normal_,
                         nn.BatchNorm2d, 1e-3, 0.1,
                         mode='fan_in')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(str(file_path), map_location=self.device)
        model.load_state_dict(checkpoint["model"])
        model.eval()
        self.model = model.to(self.device)
        file_path = Path(mean_file).resolve()
        self.mean = np.load(str(file_path))
        file_path = Path(std_file).resolve()
        self.std = np.load(str(file_path))
        self.width = width
        self.height = height

    def init_weight(self, feature, conv_init, norm_layer, bn_eps, bn_momentum,
                    **kwargs):
        for name, m in feature.named_modules():
            if isinstance(m, (nn.Conv2d, nn.Conv3d)):
                conv_init(m.weight, **kwargs)
            elif isinstance(m, norm_layer):
                m.eps = bn_eps
                m.momentum = bn_momentum
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def analyse_image(self, image):
        """
        Process image and perform inference, returns mask of detected classes
        :param image: BGR image
        :return: predicted classes mask
        """
        img = cv.resize(image, (self.width, self.height), interpolation=cv.INTER_LINEAR)
        img = np.asarray(img, np.float32) / 255.
        img = (img - self.mean) / self.std
        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img).to(self.device).unsqueeze(0)

        cuda_result = self.model.forward(img.float())
        output = cuda_result['out'].data[0].cpu().numpy()
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023-bis/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/", required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('--split', required=False, type=str, default="challenge", help='Select the split of data')
    parser.add_argument('--masks', required=False, type=bool, default=False, help='Save masks in prediction directory')
    parser.add_argument('--resolution_width', required=False, type=int, default=640,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=360,
                        help='height resolution of the images')
    args = parser.parse_args()

    lines_palette = [0, 0, 0]
    for line_class in SoccerPitch.lines_classes:
        lines_palette.extend(SoccerPitch.palette[line_class])

    calib_net = SegmentationNetwork(
        "../resources/soccer_pitch_segmentation.pth",
        "../resources/mean.npy",
        "../resources/std.npy")

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path !")
        exit(-1)

    with open(os.path.join(dataset_dir, "per_match_info.json"), 'r') as f:
        match_info = json.load(f)

    with tqdm(enumerate(match_info.keys()), total=len(match_info.keys()), ncols=160) as t:
        for i, match in t:
            frame_list = match_info[match].keys()

            for frame in frame_list:

                output_prediction_folder = os.path.join(args.prediction, args.split)
                if not os.path.exists(output_prediction_folder):
                    os.makedirs(output_prediction_folder)
                prediction = dict()
                count = 0

                frame_path = os.path.join(dataset_dir, frame)

                frame_index = frame.split(".")[0]

                image = cv.imread(frame_path)
                semlines = calib_net.analyse_image(image)
                if args.masks:
                    mask = Image.fromarray(semlines.astype(np.uint8)).convert('P')
                    mask.putpalette(lines_palette)
                    mask_file = os.path.join(output_prediction_folder, frame)
                    mask.save(mask_file)
                skeletons = generate_class_synthesis(semlines, 6)
                extremities = get_line_extremities(skeletons, 40, args.resolution_width, args.resolution_height)

                prediction = extremities
                count += 1

                prediction_file = os.path.join(output_prediction_folder, f"extremities_{frame_index}.json")
                with open(prediction_file, "w") as f:
                    json.dump(prediction, f, indent=4)
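Note (editorial): detect_extremities.py above goes segmentation mask -> disk synthesis (generate_class_synthesis) -> polyline joining (join_points) -> normalized endpoints (get_line_extremities). A small usage sketch on a synthetic mask, bypassing the DeepLabV3 network, is given below; it assumes the script and soccerpitch.py are importable (e.g. when run from sn_calibration/src), and the blob drawn here is arbitrary.

import cv2 as cv
import numpy as np

from detect_extremities import generate_class_synthesis, get_line_extremities

h, w = 360, 640
semantic_mask = np.zeros((h, w), dtype=np.uint8)
# paint a thick fake blob for class index 1 (first entry of SoccerPitch.lines_classes)
cv.line(semantic_mask, (100, 100), (500, 120), 1, 8)

skeletons = generate_class_synthesis(semantic_mask, radius=6)
extremities = get_line_extremities(skeletons, maxdist=40, width=w, height=h)
for class_name, points in extremities.items():
    print(class_name, points)   # two normalized {'x', 'y'} endpoints per detected class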
sn_calibration/src/evalai_camera.py
ADDED
@@ -0,0 +1,133 @@
import numpy as np
import zipfile
import argparse
import json
import sys

from tqdm import tqdm

sys.path.append("sn_calibration")
sys.path.append("sn_calibration/src")

from evaluate_camera import get_polylines, scale_points, evaluate_camera_prediction
from evaluate_extremities import mirror_labels


def evaluate(gt_zip, prediction_zip, width=960, height=540):
    gt_archive = zipfile.ZipFile(gt_zip, 'r')
    prediction_archive = zipfile.ZipFile(prediction_zip, 'r')
    gt_jsons = gt_archive.namelist()

    accuracies = []
    precisions = []
    recalls = []
    dict_errors = {}
    per_class_confusion_dict = {}
    total_frames = 0
    missed = 0
    for gt_json in tqdm(gt_jsons):

        #split, name = gt_json.split("/")
        #pred_name = f"{split}/camera_{name}"
        pred_name = f"camera_{gt_json}"

        total_frames += 1

        if pred_name not in prediction_archive.namelist():
            missed += 1
            continue

        prediction = prediction_archive.read(pred_name)
        prediction = json.loads(prediction.decode("utf-8"))
        gt = gt_archive.read(gt_json)
        gt = json.loads(gt.decode('utf-8'))

        line_annotations = scale_points(gt, width, height)

        img_groundtruth = line_annotations

        img_prediction = get_polylines(prediction, width, height,
                                       sampling_factor=0.9)

        confusion1, per_class_conf1, reproj_errors1 = evaluate_camera_prediction(img_prediction,
                                                                                 img_groundtruth,
                                                                                 5)

        confusion2, per_class_conf2, reproj_errors2 = evaluate_camera_prediction(img_prediction,
                                                                                 mirror_labels(img_groundtruth),
                                                                                 5)

        accuracy1, accuracy2 = 0., 0.
        if confusion1.sum() > 0:
            accuracy1 = confusion1[0, 0] / confusion1.sum()

        if confusion2.sum() > 0:
            accuracy2 = confusion2[0, 0] / confusion2.sum()

        if accuracy1 > accuracy2:
            accuracy = accuracy1
            confusion = confusion1
            per_class_conf = per_class_conf1
            reproj_errors = reproj_errors1
        else:
            accuracy = accuracy2
            confusion = confusion2
            per_class_conf = per_class_conf2
            reproj_errors = reproj_errors2

        accuracies.append(accuracy)
        if confusion[0, :].sum() > 0:
            precision = confusion[0, 0] / (confusion[0, :].sum())
            precisions.append(precision)
        if (confusion[0, 0] + confusion[1, 0]) > 0:
            recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
            recalls.append(recall)

        for line_class, errors in reproj_errors.items():
            if line_class in dict_errors.keys():
                dict_errors[line_class].extend(errors)
            else:
                dict_errors[line_class] = errors

        for line_class, confusion_mat in per_class_conf.items():
            if line_class in per_class_confusion_dict.keys():
                per_class_confusion_dict[line_class] += confusion_mat
            else:
                per_class_confusion_dict[line_class] = confusion_mat

    results = {}
    results["completeness"] = (total_frames - missed) / total_frames
    results["meanRecall"] = np.mean(recalls)
    results["meanPrecision"] = np.mean(precisions)
    results["meanAccuracies"] = np.mean(accuracies)
    results["finalScore"] = results["completeness"] * results["meanAccuracies"]

    for line_class, confusion_mat in per_class_confusion_dict.items():
        class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
        class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
        class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
        results[f"{line_class}Precision"] = class_precision
        results[f"{line_class}Recall"] = class_recall
        results[f"{line_class}Accuracy"] = class_accuracy
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluation camera calibration task')

    parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023/test_secret.zip", type=str,
                        help='Path to the zip groundtruth folder')
    parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/test.zip",
                        required=False, type=str,
                        help="Path to the zip prediction folder")
    parser.add_argument('--width', type=int, default=960)
    parser.add_argument('--height', type=int, default=540)

    args = parser.parse_args()

    results = evaluate(args.soccernet, args.prediction, args.width, args.height)
    for key in results.keys():
        print(f"{key}: {results[key]}")
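Note (editorial): in evaluate() above, the reported finalScore is the completeness rate (fraction of ground-truth frames that received a prediction) multiplied by the mean accuracy over the predicted frames. A worked example with illustrative numbers only:

total_frames = 100
missed = 10
accuracies = [0.8] * 90                                   # per-frame accuracy of the 90 predicted frames

completeness = (total_frames - missed) / total_frames     # 0.9
mean_accuracy = sum(accuracies) / len(accuracies)         # 0.8
final_score = completeness * mean_accuracy                # 0.72
print(completeness, mean_accuracy, final_score)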
sn_calibration/src/evaluate_camera.py
ADDED
@@ -0,0 +1,365 @@
import argparse
import json
import sys
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from camera import Camera
from evaluate_extremities import scale_points, distance, mirror_labels
from soccerpitch import SoccerPitch


def get_polylines(camera_annotation, width, height, sampling_factor=0.2):
    """
    Given a set of camera parameters, this function adapts the camera to the desired image resolution and then
    projects the 3D points belonging to the terrain model in order to give a dictionary associating the classes
    observed and the points projected in the image.

    :param camera_annotation: camera parameters in their json/dictionary format
    :param width: image width for evaluation
    :param height: image height for evaluation
    :return: a dictionary with keys corresponding to a class observed in the image ( a line of the 3D model whose
    projection falls in the image) and values are then the list of 2D projected points.
    """

    cam = Camera(width, height)
    cam.from_json_parameters(camera_annotation)
    field = SoccerPitch()
    projections = dict()
    sides = [
        np.array([1, 0, 0]),
        np.array([1, 0, -width + 1]),
        np.array([0, 1, 0]),
        np.array([0, 1, -height + 1])
    ]
    for key, points in field.sample_field_points(sampling_factor).items():
        projections_list = []
        in_img = False
        prev_proj = np.zeros(3)
        for i, point in enumerate(points):
            ext = cam.project_point(point)
            if ext[2] < 1e-5:
                # point at infinity or behind camera
                continue
            if 0 <= ext[0] < width and 0 <= ext[1] < height:

                if not in_img and i > 0:

                    line = np.cross(ext, prev_proj)
                    in_img_intersections = []
                    dist_to_ext = []
                    for side in sides:
                        intersection = np.cross(line, side)
                        intersection /= intersection[2]
                        if 0 <= intersection[0] < width and 0 <= intersection[1] < height:
                            in_img_intersections.append(intersection)
                            dist_to_ext.append(np.sqrt(np.sum(np.square(intersection - ext))))
                    if in_img_intersections:
                        intersection = in_img_intersections[np.argmin(dist_to_ext)]

                        projections_list.append(
                            {
                                "x": intersection[0],
                                "y": intersection[1]
                            }
                        )

                projections_list.append(
                    {
                        "x": ext[0],
                        "y": ext[1]
                    }
                )
                in_img = True
            elif in_img:
                # first point out
                line = np.cross(ext, prev_proj)

                in_img_intersections = []
                dist_to_ext = []
                for side in sides:
                    intersection = np.cross(line, side)
                    intersection /= intersection[2]
                    if 0 <= intersection[0] < width and 0 <= intersection[1] < height:
                        in_img_intersections.append(intersection)
                        dist_to_ext.append(np.sqrt(np.sum(np.square(intersection - ext))))
                if in_img_intersections:
                    intersection = in_img_intersections[np.argmin(dist_to_ext)]

                    projections_list.append(
                        {
                            "x": intersection[0],
                            "y": intersection[1]
                        }
                    )
                in_img = False
            prev_proj = ext
        if len(projections_list):
            projections[key] = projections_list
    return projections


def distance_to_polyline(point, polyline):
    """
    Computes euclidian distance between a point and a polyline.
    :param point: 2D point
    :param polyline: a list of 2D point
    :return: the distance value
    """
    if 0 < len(polyline) < 2:
        dist = distance(point, polyline[0])
        return dist
    else:
        dist_to_segments = []
        point_np = np.array([point["x"], point["y"], 1])

        for i in range(len(polyline) - 1):
            origin_segment = np.array([
                polyline[i]["x"],
                polyline[i]["y"],
                1
            ])
            end_segment = np.array([
                polyline[i + 1]["x"],
                polyline[i + 1]["y"],
                1
            ])
            line = np.cross(origin_segment, end_segment)
            line /= np.sqrt(np.square(line[0]) + np.square(line[1]))

            # project point on line l
            projected = np.cross((np.cross(np.array([line[0], line[1], 0]), point_np)), line)
            projected = projected / projected[2]

            v1 = projected - origin_segment
            v2 = end_segment - origin_segment
            k = np.dot(v1, v2) / np.dot(v2, v2)
            if 0 < k < 1:

                segment_distance = np.sqrt(np.sum(np.square(projected - point_np)))
            else:
                d1 = distance(point, polyline[i])
                d2 = distance(point, polyline[i + 1])
                segment_distance = np.min([d1, d2])

            dist_to_segments.append(segment_distance)
        return np.min(dist_to_segments)


def evaluate_camera_prediction(projected_lines, groundtruth_lines, threshold):
    """
    Computes confusion matrices for a level of precision specified by the threshold.
    A groundtruth line is correctly classified if it lies at less than threshold pixels from a line of the prediction
    of the same class.
    Computes also the reprojection error of each groundtruth point : the reprojection error is the L2 distance between
    the point and the projection of the line.
    :param projected_lines: dictionary of detected lines classes as keys and associated predicted points as values
    :param groundtruth_lines: dictionary of annotated lines classes as keys and associated annotated points as values
    :param threshold: distance in pixels that distinguishes good matches from bad ones
    :return: confusion matrix, per class confusion matrix & per class reprojection errors
    """
    global_confusion_mat = np.zeros((2, 2), dtype=np.float32)
    per_class_confusion = {}
    dict_errors = {}
    detected_classes = set(projected_lines.keys())
    groundtruth_classes = set(groundtruth_lines.keys())

    false_positives_classes = detected_classes - groundtruth_classes
    for false_positive_class in false_positives_classes:
        # false_positives = len(projected_lines[false_positive_class])
        if "Circle" not in false_positive_class:
            # Count only extremities for lines, independently of soccer pitch sampling
            false_positives = 2.
        else:
            false_positives = 9.
        per_class_confusion[false_positive_class] = np.array([[0., false_positives], [0., 0.]])
        global_confusion_mat[0, 1] += 1

    false_negatives_classes = groundtruth_classes - detected_classes
    for false_negatives_class in false_negatives_classes:
        false_negatives = len(groundtruth_lines[false_negatives_class])
        per_class_confusion[false_negatives_class] = np.array([[0., 0.], [false_negatives, 0.]])
        global_confusion_mat[1, 0] += 1

    common_classes = detected_classes - false_positives_classes

    for detected_class in common_classes:

        detected_points = projected_lines[detected_class]
        groundtruth_points = groundtruth_lines[detected_class]

        per_class_confusion[detected_class] = np.zeros((2, 2))

        all_below_dist = 1
        for point in groundtruth_points:

            dist_to_poly = distance_to_polyline(point, detected_points)
            if dist_to_poly < threshold:
                per_class_confusion[detected_class][0, 0] += 1
            else:
                per_class_confusion[detected_class][0, 1] += 1
                all_below_dist *= 0

            if detected_class in dict_errors.keys():
                dict_errors[detected_class].append(dist_to_poly)
            else:
                dict_errors[detected_class] = [dist_to_poly]

        if all_below_dist:
            global_confusion_mat[0, 0] += 1
        else:
            global_confusion_mat[0, 1] += 1

    return global_confusion_mat, per_class_confusion, dict_errors


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Evaluation camera calibration task')

    parser.add_argument('-s', '--soccernet', default="/home/fmg/data/SN23/calibration-2023-bis/", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="/home/fmg/results/SN23-tests/",
                        required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('-t', '--threshold', default=5, required=False, type=int,
                        help="Accuracy threshold in pixels")
    parser.add_argument('--split', required=False, type=str, default="valid", help='Select the split of data')
    parser.add_argument('--resolution_width', required=False, type=int, default=960,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=540,
                        help='height resolution of the images')
    args = parser.parse_args()

    accuracies = []
    precisions = []
    recalls = []
    dict_errors = {}
    per_class_confusion_dict = {}

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path !")
        exit(-1)

    annotation_files = [f for f in os.listdir(dataset_dir) if ".json" in f]

    missed, total_frames = 0, 0
    with tqdm(enumerate(annotation_files), total=len(annotation_files), ncols=160) as t:
        for i, annotation_file in t:
            frame_index = annotation_file.split(".")[0]
            annotation_file = os.path.join(args.soccernet, args.split, annotation_file)
            prediction_file = os.path.join(args.prediction, args.split, f"camera_{frame_index}.json")

            total_frames += 1

            if not os.path.exists(prediction_file):
                missed += 1

                continue

            with open(annotation_file, 'r') as f:
                line_annotations = json.load(f)

            with open(prediction_file, 'r') as f:
                predictions = json.load(f)

            line_annotations = scale_points(line_annotations, args.resolution_width, args.resolution_height)

            image_path = os.path.join(args.soccernet, args.split, f"{frame_index}.jpg")

            img_groundtruth = line_annotations

            img_prediction = get_polylines(predictions, args.resolution_width, args.resolution_height,
                                           sampling_factor=0.9)

            confusion1, per_class_conf1, reproj_errors1 = evaluate_camera_prediction(img_prediction,
                                                                                     img_groundtruth,
                                                                                     args.threshold)

            confusion2, per_class_conf2, reproj_errors2 = evaluate_camera_prediction(img_prediction,
                                                                                     mirror_labels(img_groundtruth),
                                                                                     args.threshold)

            accuracy1, accuracy2 = 0., 0.
            if confusion1.sum() > 0:
                accuracy1 = confusion1[0, 0] / confusion1.sum()

            if confusion2.sum() > 0:
                accuracy2 = confusion2[0, 0] / confusion2.sum()

            if accuracy1 > accuracy2:
                accuracy = accuracy1
                confusion = confusion1
                per_class_conf = per_class_conf1
                reproj_errors = reproj_errors1
            else:
                accuracy = accuracy2
                confusion = confusion2
                per_class_conf = per_class_conf2
                reproj_errors = reproj_errors2

            accuracies.append(accuracy)
            if confusion[0, :].sum() > 0:
                precision = confusion[0, 0] / (confusion[0, :].sum())
                precisions.append(precision)
            if (confusion[0, 0] + confusion[1, 0]) > 0:
                recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
                recalls.append(recall)

            for line_class, errors in reproj_errors.items():
                if line_class in dict_errors.keys():
                    dict_errors[line_class].extend(errors)
                else:
                    dict_errors[line_class] = errors

            for line_class, confusion_mat in per_class_conf.items():
                if line_class in per_class_confusion_dict.keys():
                    per_class_confusion_dict[line_class] += confusion_mat
                else:
                    per_class_confusion_dict[line_class] = confusion_mat

    completeness_score = (total_frames - missed) / total_frames
    mAccuracy = np.mean(accuracies)

    final_score = completeness_score * mAccuracy
    print(f" On SoccerNet {args.split} set, final score of : {final_score}")
    print(f" On SoccerNet {args.split} set, completeness rate of : {completeness_score}")

    mRecall = np.mean(recalls)
    sRecall = np.std(recalls)
    medianRecall = np.median(recalls)
    print(
        f" On SoccerNet {args.split} set, recall mean value : {mRecall * 100:2.2f}% with standard deviation of {sRecall * 100:2.2f}% and median of {medianRecall * 100:2.2f}%")

    mPrecision = np.mean(precisions)
    sPrecision = np.std(precisions)
    medianPrecision = np.median(precisions)
    print(
        f" On SoccerNet {args.split} set, precision mean value : {mPrecision * 100:2.2f}% with standard deviation of {sPrecision * 100:2.2f}% and median of {medianPrecision * 100:2.2f}%")

    sAccuracy = np.std(accuracies)
    medianAccuracy = np.median(accuracies)
    print(
        f" On SoccerNet {args.split} set, accuracy mean value : {mAccuracy * 100:2.2f}% with standard deviation of {sAccuracy * 100:2.2f}% and median of {medianAccuracy * 100:2.2f}%")

    print()

    for line_class, confusion_mat in per_class_confusion_dict.items():
        class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
        class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
        class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
        print(
            f"For class {line_class}, accuracy of {class_accuracy * 100:2.2f}%, precision of {class_precision * 100:2.2f}% and recall of {class_recall * 100:2.2f}%")

    for k, v in dict_errors.items():
        fig, ax1 = plt.subplots(figsize=(11, 8))
        ax1.hist(v, bins=30, range=(0, 60))
        ax1.set_title(k)
        ax1.set_xlabel("Errors in pixel")
        os.makedirs(f"./results/", exist_ok=True)
        plt.savefig(f"./results/{k}_reprojection_error.png")
        plt.close(fig)
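Note (editorial): distance_to_polyline above measures the distance from a ground-truth point to the projected polyline as the minimum over point-to-segment distances, projecting the point onto each segment's supporting line in homogeneous coordinates and falling back to the nearer endpoint when the projection lands outside the segment. A quick sanity check with hypothetical points, assuming evaluate_camera.py and its imports (camera.py, soccerpitch.py) are importable from sn_calibration/src:

from evaluate_camera import distance_to_polyline

polyline = [{"x": 0.0, "y": 0.0}, {"x": 100.0, "y": 0.0}]   # horizontal segment
point = {"x": 50.0, "y": 5.0}                                # 5 px above its midpoint
print(distance_to_polyline(point, polyline))                 # expected: 5.0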
sn_calibration/src/evaluate_extremities.py
ADDED
@@ -0,0 +1,274 @@
import argparse
import json
import sys
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

sys.path.append("sn_calibration")
sys.path.append("sn_calibration/src")

from soccerpitch import SoccerPitch


def distance(point1, point2):
    """
    Computes euclidian distance between 2D points
    :param point1
    :param point2
    :return: euclidian distance between point1 and point2
    """
    diff = np.array([point1['x'], point1['y']]) - np.array([point2['x'], point2['y']])
    sq_dist = np.square(diff)
    return np.sqrt(sq_dist.sum())


def mirror_labels(lines_dict):
    """
    Replace each line class key of the dictionary with its opposite element according to a central projection by the
    soccer pitch center
    :param lines_dict: dictionary whose keys will be mirrored
    :return: Dictionary with mirrored keys and same values
    """
    mirrored_dict = dict()
    for line_class, value in lines_dict.items():
        mirrored_dict[SoccerPitch.symetric_classes[line_class]] = value
    return mirrored_dict


def evaluate_detection_prediction(detected_lines, groundtruth_lines, threshold=2.):
    """
    Evaluates the prediction of extremities. The extremities associated to a class are unordered. The extremities of the
    "Circle central" element is not well-defined for this task, thus this class is ignored.
    Computes confusion matrices for a level of precision specified by the threshold.
    A groundtruth extremity point is correctly classified if it lies at less than threshold pixels from the
    corresponding extremity point of the prediction of the same class.
    Computes also the euclidian distance between each predicted extremity and its closest groundtruth extremity, when
    both the groundtruth and the prediction contain the element class.

    :param detected_lines: dictionary of detected lines classes as keys and associated predicted extremities as values
    :param groundtruth_lines: dictionary of annotated lines classes as keys and associated annotated points as values
    :param threshold: distance in pixels that distinguishes good matches from bad ones
    :return: confusion matrix, per class confusion matrix & per class localization errors
    """
    confusion_mat = np.zeros((2, 2), dtype=np.float32)
    per_class_confusion = {}
    errors_dict = {}
    detected_classes = set(detected_lines.keys())
    groundtruth_classes = set(groundtruth_lines.keys())

    if "Circle central" in groundtruth_classes:
        groundtruth_classes.remove("Circle central")
    if "Circle central" in detected_classes:
        detected_classes.remove("Circle central")

    false_positives_classes = detected_classes - groundtruth_classes
    for false_positive_class in false_positives_classes:
        false_positives = len(detected_lines[false_positive_class])
        confusion_mat[0, 1] += false_positives
        per_class_confusion[false_positive_class] = np.array([[0., false_positives], [0., 0.]])

    false_negatives_classes = groundtruth_classes - detected_classes
    for false_negatives_class in false_negatives_classes:
        false_negatives = len(groundtruth_lines[false_negatives_class])
        confusion_mat[1, 0] += false_negatives
        per_class_confusion[false_negatives_class] = np.array([[0., 0.], [false_negatives, 0.]])

    common_classes = detected_classes - false_positives_classes

    for detected_class in common_classes:

        detected_points = detected_lines[detected_class]

        groundtruth_points = groundtruth_lines[detected_class]

        groundtruth_extremities = [groundtruth_points[0], groundtruth_points[-1]]
        predicted_extremities = [detected_points[0], detected_points[-1]]
        per_class_confusion[detected_class] = np.zeros((2, 2))

        dist1 = distance(groundtruth_extremities[0], predicted_extremities[0])
        dist1rev = distance(groundtruth_extremities[1], predicted_extremities[0])

        dist2 = distance(groundtruth_extremities[1], predicted_extremities[1])
        dist2rev = distance(groundtruth_extremities[0], predicted_extremities[1])
        if dist1rev <= dist1 and dist2rev <= dist2:
            # reverse order
            dist1 = dist1rev
            dist2 = dist2rev

        errors_dict[detected_class] = [dist1, dist2]

        if dist1 < threshold:
            confusion_mat[0, 0] += 1
            per_class_confusion[detected_class][0, 0] += 1
        else:
            # treat too far detections as false positives
            confusion_mat[0, 1] += 1
            per_class_confusion[detected_class][0, 1] += 1

        if dist2 < threshold:
            confusion_mat[0, 0] += 1
            per_class_confusion[detected_class][0, 0] += 1

        else:
            # treat too far detections as false positives
            confusion_mat[0, 1] += 1
            per_class_confusion[detected_class][0, 1] += 1

    return confusion_mat, per_class_confusion, errors_dict


def scale_points(points_dict, s_width, s_height):
    """
    Scale points by s_width and s_height factors
    :param points_dict: dictionary of annotations/predictions with normalized point values
    :param s_width: width scaling factor
    :param s_height: height scaling factor
    :return: dictionary with scaled points
    """
    line_dict = {}
    for line_class, points in points_dict.items():
        scaled_points = []
        for point in points:
            new_point = {'x': point['x'] * (s_width-1), 'y': point['y'] * (s_height-1)}
            scaled_points.append(new_point)
        if len(scaled_points):
            line_dict[line_class] = scaled_points
    return line_dict


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Test')

    parser.add_argument('-s', '--soccernet', default="./annotations", type=str,
                        help='Path to the SoccerNet-V3 dataset folder')
    parser.add_argument('-p', '--prediction', default="./results_bis",
                        required=False, type=str,
                        help="Path to the prediction folder")
    parser.add_argument('-t', '--threshold', default=10, required=False, type=int,
                        help="Accuracy threshold in pixels")
    parser.add_argument('--split', required=False, type=str, default="test", help='Select the split of data')
    parser.add_argument('--resolution_width', required=False, type=int, default=960,
                        help='width resolution of the images')
    parser.add_argument('--resolution_height', required=False, type=int, default=540,
                        help='height resolution of the images')
    args = parser.parse_args()

    accuracies = []
    precisions = []
    recalls = []
    dict_errors = {}
    per_class_confusion_dict = {}

    dataset_dir = os.path.join(args.soccernet, args.split)
    if not os.path.exists(dataset_dir):
        print("Invalid dataset path !")
        exit(-1)

    annotation_files = [f for f in os.listdir(dataset_dir) if ".json" in f]

    with tqdm(enumerate(annotation_files), total=len(annotation_files), ncols=160) as t:
        for i, annotation_file in t:
            frame_index = annotation_file.split(".")[0]
            annotation_file = os.path.join(args.soccernet, args.split, annotation_file)
            prediction_file = os.path.join(args.prediction, args.split, f"extremities_{frame_index}.json")

            if not os.path.exists(prediction_file):
                accuracies.append(0.)
                precisions.append(0.)
                recalls.append(0.)
                continue

            with open(annotation_file, 'r') as f:
                line_annotations = json.load(f)

            with open(prediction_file, 'r') as f:
                predictions = json.load(f)

            predictions = scale_points(predictions, args.resolution_width, args.resolution_height)
            line_annotations = scale_points(line_annotations, args.resolution_width, args.resolution_height)

            img_prediction = predictions
            img_groundtruth = line_annotations
            confusion1, per_class_conf1, reproj_errors1 = evaluate_detection_prediction(img_prediction,
                                                                                        img_groundtruth,
                                                                                        args.threshold)
            confusion2, per_class_conf2, reproj_errors2 = evaluate_detection_prediction(img_prediction,
                                                                                        mirror_labels(
                                                                                            img_groundtruth),
                                                                                        args.threshold)

            accuracy1, accuracy2 = 0., 0.
            if confusion1.sum() > 0:
                accuracy1 = confusion1[0, 0] / confusion1.sum()

            if confusion2.sum() > 0:
                accuracy2 = confusion2[0, 0] / confusion2.sum()

            if accuracy1 > accuracy2:
                accuracy = accuracy1
                confusion = confusion1
                per_class_conf = per_class_conf1
                reproj_errors = reproj_errors1
            else:
                accuracy = accuracy2
                confusion = confusion2
                per_class_conf = per_class_conf2
                reproj_errors = reproj_errors2

            accuracies.append(accuracy)
            if confusion[0, :].sum() > 0:
                precision = confusion[0, 0] / (confusion[0, :].sum())
                precisions.append(precision)
            if (confusion[0, 0] + confusion[1, 0]) > 0:
                recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])
                recalls.append(recall)

            for line_class, errors in reproj_errors.items():
                if line_class in dict_errors.keys():
                    dict_errors[line_class].extend(errors)
                else:
                    dict_errors[line_class] = errors

            for line_class, confusion_mat in per_class_conf.items():
                if line_class in per_class_confusion_dict.keys():
                    per_class_confusion_dict[line_class] += confusion_mat
                else:
                    per_class_confusion_dict[line_class] = confusion_mat

    mRecall = np.mean(recalls)
    sRecall = np.std(recalls)
    medianRecall = np.median(recalls)
    print(
        f" On SoccerNet {args.split} set, recall mean value : {mRecall * 100:2.2f}% with standard deviation of {sRecall * 100:2.2f}% and median of {medianRecall * 100:2.2f}%")

    mPrecision = np.mean(precisions)
    sPrecision = np.std(precisions)
    medianPrecision = np.median(precisions)
    print(
        f" On SoccerNet {args.split} set, precision mean value : {mPrecision * 100:2.2f}% with standard deviation of {sPrecision * 100:2.2f}% and median of {medianPrecision * 100:2.2f}%")

    mAccuracy = np.mean(accuracies)
    sAccuracy = np.std(accuracies)
    medianAccuracy = np.median(accuracies)
    print(
        f" On SoccerNet {args.split} set, accuracy mean value : {mAccuracy * 100:2.2f}% with standard deviation of {sAccuracy * 100:2.2f}% and median of {medianAccuracy * 100:2.2f}%")

    for line_class, confusion_mat in per_class_confusion_dict.items():
        class_accuracy = confusion_mat[0, 0] / confusion_mat.sum()
        class_recall = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[1, 0])
        class_precision = confusion_mat[0, 0] / (confusion_mat[0, 0] + confusion_mat[0, 1])
        print(
            f"For class {line_class}, accuracy of {class_accuracy * 100:2.2f}%, precision of {class_precision * 100:2.2f}% and recall of {class_recall * 100:2.2f}%")

    for k, v in dict_errors.items():
        fig, ax1 = plt.subplots(figsize=(11, 8))
        ax1.hist(v, bins=30, range=(0, 60))
        ax1.set_title(k)
        ax1.set_xlabel("Errors in pixel")
        os.makedirs(f"./results/", exist_ok=True)
        plt.savefig(f"./results/{k}_detection_error.png")
        plt.close(fig)
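Note (editorial): both evaluation scripts accumulate 2x2 matrices in which entry [0, 0] counts true positives, [0, 1] false positives and [1, 0] false negatives, and derive accuracy, precision and recall from them exactly as in the per-class summaries above. A toy example with made-up counts:

import numpy as np

confusion = np.array([[8., 1.],    # 8 true positives, 1 false positive
                      [3., 0.]])   # 3 false negatives

accuracy = confusion[0, 0] / confusion.sum()                     # TP / (TP + FP + FN) ~ 0.67
precision = confusion[0, 0] / confusion[0, :].sum()              # TP / (TP + FP)      ~ 0.89
recall = confusion[0, 0] / (confusion[0, 0] + confusion[1, 0])   # TP / (TP + FN)      ~ 0.73
print(accuracy, precision, recall)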
sn_calibration/src/soccerpitch.py
ADDED
@@ -0,0 +1,528 @@
import numpy as np


class SoccerPitch:
    """Static class variables that are specified by the rules of the game """
    GOAL_LINE_TO_PENALTY_MARK = 11.0
    PENALTY_AREA_WIDTH = 40.32
    PENALTY_AREA_LENGTH = 16.5
    GOAL_AREA_WIDTH = 18.32
    GOAL_AREA_LENGTH = 5.5
    CENTER_CIRCLE_RADIUS = 9.15
    GOAL_HEIGHT = 2.44
    GOAL_LENGTH = 7.32

    lines_classes = [
        'Big rect. left bottom',
        'Big rect. left main',
        'Big rect. left top',
        'Big rect. right bottom',
        'Big rect. right main',
        'Big rect. right top',
        'Circle central',
        'Circle left',
        'Circle right',
        'Goal left crossbar',
        'Goal left post left ',
        'Goal left post right',
        'Goal right crossbar',
        'Goal right post left',
        'Goal right post right',
        'Goal unknown',
        'Line unknown',
        'Middle line',
        'Side line bottom',
        'Side line left',
        'Side line right',
        'Side line top',
        'Small rect. left bottom',
        'Small rect. left main',
        'Small rect. left top',
        'Small rect. right bottom',
        'Small rect. right main',
        'Small rect. right top'
    ]

    symetric_classes = {
        'Side line top': 'Side line bottom',
        'Side line bottom': 'Side line top',
        'Side line left': 'Side line right',
        'Middle line': 'Middle line',
        'Side line right': 'Side line left',
        'Big rect. left top': 'Big rect. right bottom',
        'Big rect. left bottom': 'Big rect. right top',
        'Big rect. left main': 'Big rect. right main',
        'Big rect. right top': 'Big rect. left bottom',
        'Big rect. right bottom': 'Big rect. left top',
        'Big rect. right main': 'Big rect. left main',
        'Small rect. left top': 'Small rect. right bottom',
        'Small rect. left bottom': 'Small rect. right top',
        'Small rect. left main': 'Small rect. right main',
        'Small rect. right top': 'Small rect. left bottom',
        'Small rect. right bottom': 'Small rect. left top',
        'Small rect. right main': 'Small rect. left main',
        'Circle left': 'Circle right',
        'Circle central': 'Circle central',
        'Circle right': 'Circle left',
        'Goal left crossbar': 'Goal right crossbar',
        'Goal left post left ': 'Goal right post left',
        'Goal left post right': 'Goal right post right',
        'Goal right crossbar': 'Goal left crossbar',
        'Goal right post left': 'Goal left post left ',
        'Goal right post right': 'Goal left post right',
        'Goal unknown': 'Goal unknown',
        'Line unknown': 'Line unknown'
    }

    # RGB values
    palette = {
        'Big rect. left bottom': (127, 0, 0),
        'Big rect. left main': (102, 102, 102),
        'Big rect. left top': (0, 0, 127),
        'Big rect. right bottom': (86, 32, 39),
        'Big rect. right main': (48, 77, 0),
        'Big rect. right top': (14, 97, 100),
        'Circle central': (0, 0, 255),
        'Circle left': (255, 127, 0),
        'Circle right': (0, 255, 255),
        'Goal left crossbar': (255, 255, 200),
        'Goal left post left ': (165, 255, 0),
        'Goal left post right': (155, 119, 45),
        'Goal right crossbar': (86, 32, 139),
        'Goal right post left': (196, 120, 153),
        'Goal right post right': (166, 36, 52),
        'Goal unknown': (0, 0, 0),
        'Line unknown': (0, 0, 0),
        'Middle line': (255, 255, 0),
        'Side line bottom': (255, 0, 255),
        'Side line left': (0, 255, 150),
        'Side line right': (0, 230, 0),
        'Side line top': (230, 0, 0),
        'Small rect. left bottom': (0, 150, 255),
        'Small rect. left main': (254, 173, 225),
        'Small rect. left top': (87, 72, 39),
        'Small rect. right bottom': (122, 0, 255),
        'Small rect. right main': (255, 255, 255),
        'Small rect. right top': (153, 23, 153)
    }

    def __init__(self, pitch_length=105., pitch_width=68.):
        """
        Initialize 3D coordinates of all elements of the soccer pitch.
        :param pitch_length: According to FIFA rules, length belong to [90,120] meters
        :param pitch_width: According to FIFA rules, length belong to [45,90] meters
        """
        self.PITCH_LENGTH = pitch_length
        self.PITCH_WIDTH = pitch_width

        self.center_mark = np.array([0, 0, 0], dtype='float')
        self.halfway_and_bottom_touch_line_mark = np.array([0, pitch_width / 2.0, 0], dtype='float')
        self.halfway_and_top_touch_line_mark = np.array([0, -pitch_width / 2.0, 0], dtype='float')
        self.halfway_line_and_center_circle_top_mark = np.array([0, -SoccerPitch.CENTER_CIRCLE_RADIUS, 0],
                                                                dtype='float')
        self.halfway_line_and_center_circle_bottom_mark = np.array([0, SoccerPitch.CENTER_CIRCLE_RADIUS, 0],
                                                                   dtype='float')
        self.bottom_right_corner = np.array([pitch_length / 2.0, pitch_width / 2.0, 0], dtype='float')
        self.bottom_left_corner = np.array([-pitch_length / 2.0, pitch_width / 2.0, 0], dtype='float')
        self.top_right_corner = np.array([pitch_length / 2.0, -pitch_width / 2.0, 0], dtype='float')
        self.top_left_corner = np.array([-pitch_length / 2.0, -pitch_width / 2.0, 0], dtype='float')

        self.left_goal_bottom_left_post = np.array([-pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., 0.],
                                                   dtype='float')
        self.left_goal_top_left_post = np.array(
            [-pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')
        self.left_goal_bottom_right_post = np.array([-pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., 0.],
                                                    dtype='float')
        self.left_goal_top_right_post = np.array(
            [-pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')

        self.right_goal_bottom_left_post = np.array([pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., 0.],
                                                    dtype='float')
        self.right_goal_top_left_post = np.array(
            [pitch_length / 2.0, -SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')
        self.right_goal_bottom_right_post = np.array([pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., 0.],
                                                     dtype='float')
        self.right_goal_top_right_post = np.array(
            [pitch_length / 2.0, SoccerPitch.GOAL_LENGTH / 2., -SoccerPitch.GOAL_HEIGHT], dtype='float')

        self.left_penalty_mark = np.array([-pitch_length / 2.0 + SoccerPitch.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
                                          dtype='float')
        self.right_penalty_mark = np.array([pitch_length / 2.0 - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK, 0, 0],
                                           dtype='float')

        self.left_penalty_area_top_right_corner = np.array(
            [-pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.left_penalty_area_top_left_corner = np.array(
            [-pitch_length / 2.0, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.left_penalty_area_bottom_right_corner = np.array(
            [-pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.left_penalty_area_bottom_left_corner = np.array(
            [-pitch_length / 2.0, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.right_penalty_area_top_right_corner = np.array(
            [pitch_length / 2.0, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.right_penalty_area_top_left_corner = np.array(
            [pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH, -SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.right_penalty_area_bottom_right_corner = np.array(
            [pitch_length / 2.0, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')
        self.right_penalty_area_bottom_left_corner = np.array(
            [pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH, SoccerPitch.PENALTY_AREA_WIDTH / 2.0, 0],
            dtype='float')

        self.left_goal_area_top_right_corner = np.array(
            [-pitch_length / 2.0 + SoccerPitch.GOAL_AREA_LENGTH, -SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
180 |
+
self.left_goal_area_top_left_corner = np.array([-pitch_length / 2.0, - SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
|
181 |
+
dtype='float')
|
182 |
+
self.left_goal_area_bottom_right_corner = np.array(
|
183 |
+
[-pitch_length / 2.0 + SoccerPitch.GOAL_AREA_LENGTH, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
|
184 |
+
self.left_goal_area_bottom_left_corner = np.array([-pitch_length / 2.0, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
|
185 |
+
dtype='float')
|
186 |
+
self.right_goal_area_top_right_corner = np.array([pitch_length / 2.0, -SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
|
187 |
+
dtype='float')
|
188 |
+
self.right_goal_area_top_left_corner = np.array(
|
189 |
+
[pitch_length / 2.0 - SoccerPitch.GOAL_AREA_LENGTH, -SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
|
190 |
+
self.right_goal_area_bottom_right_corner = np.array([pitch_length / 2.0, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0],
|
191 |
+
dtype='float')
|
192 |
+
self.right_goal_area_bottom_left_corner = np.array(
|
193 |
+
[pitch_length / 2.0 - SoccerPitch.GOAL_AREA_LENGTH, SoccerPitch.GOAL_AREA_WIDTH / 2.0, 0], dtype='float')
|
194 |
+
|
195 |
+
x = -pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH;
|
196 |
+
dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK;
|
197 |
+
y = -np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx);
|
198 |
+
self.top_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
|
199 |
+
|
200 |
+
x = pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH;
|
201 |
+
dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK;
|
202 |
+
y = -np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx);
|
203 |
+
self.top_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
|
204 |
+
|
205 |
+
x = -pitch_length / 2.0 + SoccerPitch.PENALTY_AREA_LENGTH;
|
206 |
+
dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK;
|
207 |
+
y = np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx);
|
208 |
+
self.bottom_left_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
|
209 |
+
|
210 |
+
x = pitch_length / 2.0 - SoccerPitch.PENALTY_AREA_LENGTH;
|
211 |
+
dx = SoccerPitch.PENALTY_AREA_LENGTH - SoccerPitch.GOAL_LINE_TO_PENALTY_MARK;
|
212 |
+
y = np.sqrt(SoccerPitch.CENTER_CIRCLE_RADIUS * SoccerPitch.CENTER_CIRCLE_RADIUS - dx * dx);
|
213 |
+
self.bottom_right_16M_penalty_arc_mark = np.array([x, y, 0], dtype='float')
|
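A brief note on the four `*_16M_penalty_arc_mark` points computed just above: the penalty arc around each penalty mark uses the same radius as the centre circle (`SoccerPitch.CENTER_CIRCLE_RADIUS`), and the penalty-area line sits `dx = PENALTY_AREA_LENGTH - GOAL_LINE_TO_PENALTY_MARK` metres from the mark, so the arc crosses that line at ordinates `y = ±sqrt(r² − dx²)`. The standalone check below is only an illustration and assumes the usual FIFA values for those class constants (16.5 m, 11 m and 9.15 m), which are defined earlier in this file but not shown here:

```python
import numpy as np

# Assumed FIFA values for the class constants used above (metres).
PENALTY_AREA_LENGTH = 16.5
GOAL_LINE_TO_PENALTY_MARK = 11.0
CENTER_CIRCLE_RADIUS = 9.15

dx = PENALTY_AREA_LENGTH - GOAL_LINE_TO_PENALTY_MARK      # 5.5 m from penalty mark to area line
y = np.sqrt(CENTER_CIRCLE_RADIUS ** 2 - dx ** 2)          # ~7.31 m above/below the pitch axis
print(dx, y)
```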
+
+        # self.set_elevations(elevation)
+
+        self.point_dict = {}
+        self.point_dict["CENTER_MARK"] = self.center_mark
+        self.point_dict["L_PENALTY_MARK"] = self.left_penalty_mark
+        self.point_dict["R_PENALTY_MARK"] = self.right_penalty_mark
+        self.point_dict["TL_PITCH_CORNER"] = self.top_left_corner
+        self.point_dict["BL_PITCH_CORNER"] = self.bottom_left_corner
+        self.point_dict["TR_PITCH_CORNER"] = self.top_right_corner
+        self.point_dict["BR_PITCH_CORNER"] = self.bottom_right_corner
+        self.point_dict["L_PENALTY_AREA_TL_CORNER"] = self.left_penalty_area_top_left_corner
+        self.point_dict["L_PENALTY_AREA_TR_CORNER"] = self.left_penalty_area_top_right_corner
+        self.point_dict["L_PENALTY_AREA_BL_CORNER"] = self.left_penalty_area_bottom_left_corner
+        self.point_dict["L_PENALTY_AREA_BR_CORNER"] = self.left_penalty_area_bottom_right_corner
+
+        self.point_dict["R_PENALTY_AREA_TL_CORNER"] = self.right_penalty_area_top_left_corner
+        self.point_dict["R_PENALTY_AREA_TR_CORNER"] = self.right_penalty_area_top_right_corner
+        self.point_dict["R_PENALTY_AREA_BL_CORNER"] = self.right_penalty_area_bottom_left_corner
+        self.point_dict["R_PENALTY_AREA_BR_CORNER"] = self.right_penalty_area_bottom_right_corner
+
+        self.point_dict["L_GOAL_AREA_TL_CORNER"] = self.left_goal_area_top_left_corner
+        self.point_dict["L_GOAL_AREA_TR_CORNER"] = self.left_goal_area_top_right_corner
+        self.point_dict["L_GOAL_AREA_BL_CORNER"] = self.left_goal_area_bottom_left_corner
+        self.point_dict["L_GOAL_AREA_BR_CORNER"] = self.left_goal_area_bottom_right_corner
+
+        self.point_dict["R_GOAL_AREA_TL_CORNER"] = self.right_goal_area_top_left_corner
+        self.point_dict["R_GOAL_AREA_TR_CORNER"] = self.right_goal_area_top_right_corner
+        self.point_dict["R_GOAL_AREA_BL_CORNER"] = self.right_goal_area_bottom_left_corner
+        self.point_dict["R_GOAL_AREA_BR_CORNER"] = self.right_goal_area_bottom_right_corner
+
+        self.point_dict["L_GOAL_TL_POST"] = self.left_goal_top_left_post
+        self.point_dict["L_GOAL_TR_POST"] = self.left_goal_top_right_post
+        self.point_dict["L_GOAL_BL_POST"] = self.left_goal_bottom_left_post
+        self.point_dict["L_GOAL_BR_POST"] = self.left_goal_bottom_right_post
+
+        self.point_dict["R_GOAL_TL_POST"] = self.right_goal_top_left_post
+        self.point_dict["R_GOAL_TR_POST"] = self.right_goal_top_right_post
+        self.point_dict["R_GOAL_BL_POST"] = self.right_goal_bottom_left_post
+        self.point_dict["R_GOAL_BR_POST"] = self.right_goal_bottom_right_post
+
+        self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"] = self.halfway_and_top_touch_line_mark
+        self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"] = self.halfway_and_bottom_touch_line_mark
+        self.point_dict["T_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"] = self.halfway_line_and_center_circle_top_mark
+        self.point_dict[
+            "B_HALFWAY_LINE_AND_CENTER_CIRCLE_INTERSECTION"] = self.halfway_line_and_center_circle_bottom_mark
+        self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.top_left_16M_penalty_arc_mark
+        self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.bottom_left_16M_penalty_arc_mark
+        self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.top_right_16M_penalty_arc_mark
+        self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"] = self.bottom_right_16M_penalty_arc_mark
+
+        self.line_extremities = dict()
+        self.line_extremities["Big rect. left bottom"] = (self.point_dict["L_PENALTY_AREA_BL_CORNER"],
+            self.point_dict["L_PENALTY_AREA_BR_CORNER"])
+        self.line_extremities["Big rect. left top"] = (self.point_dict["L_PENALTY_AREA_TL_CORNER"],
+            self.point_dict["L_PENALTY_AREA_TR_CORNER"])
+        self.line_extremities["Big rect. left main"] = (self.point_dict["L_PENALTY_AREA_TR_CORNER"],
+            self.point_dict["L_PENALTY_AREA_BR_CORNER"])
+        self.line_extremities["Big rect. right bottom"] = (self.point_dict["R_PENALTY_AREA_BL_CORNER"],
+            self.point_dict["R_PENALTY_AREA_BR_CORNER"])
+        self.line_extremities["Big rect. right top"] = (self.point_dict["R_PENALTY_AREA_TL_CORNER"],
+            self.point_dict["R_PENALTY_AREA_TR_CORNER"])
+        self.line_extremities["Big rect. right main"] = (self.point_dict["R_PENALTY_AREA_TL_CORNER"],
+            self.point_dict["R_PENALTY_AREA_BL_CORNER"])
+
+        self.line_extremities["Small rect. left bottom"] = (self.point_dict["L_GOAL_AREA_BL_CORNER"],
+            self.point_dict["L_GOAL_AREA_BR_CORNER"])
+        self.line_extremities["Small rect. left top"] = (self.point_dict["L_GOAL_AREA_TL_CORNER"],
+            self.point_dict["L_GOAL_AREA_TR_CORNER"])
+        self.line_extremities["Small rect. left main"] = (self.point_dict["L_GOAL_AREA_TR_CORNER"],
+            self.point_dict["L_GOAL_AREA_BR_CORNER"])
+        self.line_extremities["Small rect. right bottom"] = (self.point_dict["R_GOAL_AREA_BL_CORNER"],
+            self.point_dict["R_GOAL_AREA_BR_CORNER"])
+        self.line_extremities["Small rect. right top"] = (self.point_dict["R_GOAL_AREA_TL_CORNER"],
+            self.point_dict["R_GOAL_AREA_TR_CORNER"])
+        self.line_extremities["Small rect. right main"] = (self.point_dict["R_GOAL_AREA_TL_CORNER"],
+            self.point_dict["R_GOAL_AREA_BL_CORNER"])
+
+        self.line_extremities["Side line top"] = (self.point_dict["TL_PITCH_CORNER"],
+            self.point_dict["TR_PITCH_CORNER"])
+        self.line_extremities["Side line bottom"] = (self.point_dict["BL_PITCH_CORNER"],
+            self.point_dict["BR_PITCH_CORNER"])
+        self.line_extremities["Side line left"] = (self.point_dict["TL_PITCH_CORNER"],
+            self.point_dict["BL_PITCH_CORNER"])
+        self.line_extremities["Side line right"] = (self.point_dict["TR_PITCH_CORNER"],
+            self.point_dict["BR_PITCH_CORNER"])
+        self.line_extremities["Middle line"] = (self.point_dict["T_TOUCH_AND_HALFWAY_LINES_INTERSECTION"],
+            self.point_dict["B_TOUCH_AND_HALFWAY_LINES_INTERSECTION"])
+
+        self.line_extremities["Goal left crossbar"] = (self.point_dict["L_GOAL_TR_POST"],
+            self.point_dict["L_GOAL_TL_POST"])
+        self.line_extremities["Goal left post left "] = (self.point_dict["L_GOAL_TL_POST"],
+            self.point_dict["L_GOAL_BL_POST"])
+        self.line_extremities["Goal left post right"] = (self.point_dict["L_GOAL_TR_POST"],
+            self.point_dict["L_GOAL_BR_POST"])
+
+        self.line_extremities["Goal right crossbar"] = (self.point_dict["R_GOAL_TL_POST"],
+            self.point_dict["R_GOAL_TR_POST"])
+        self.line_extremities["Goal right post left"] = (self.point_dict["R_GOAL_TL_POST"],
+            self.point_dict["R_GOAL_BL_POST"])
+        self.line_extremities["Goal right post right"] = (self.point_dict["R_GOAL_TR_POST"],
+            self.point_dict["R_GOAL_BR_POST"])
+        self.line_extremities["Circle right"] = (self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
+            self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"])
+        self.line_extremities["Circle left"] = (self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"],
+            self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"])
+
+        self.line_extremities_keys = dict()
+        self.line_extremities_keys["Big rect. left bottom"] = ("L_PENALTY_AREA_BL_CORNER",
+            "L_PENALTY_AREA_BR_CORNER")
+        self.line_extremities_keys["Big rect. left top"] = ("L_PENALTY_AREA_TL_CORNER",
+            "L_PENALTY_AREA_TR_CORNER")
+        self.line_extremities_keys["Big rect. left main"] = ("L_PENALTY_AREA_TR_CORNER",
+            "L_PENALTY_AREA_BR_CORNER")
+        self.line_extremities_keys["Big rect. right bottom"] = ("R_PENALTY_AREA_BL_CORNER",
+            "R_PENALTY_AREA_BR_CORNER")
+        self.line_extremities_keys["Big rect. right top"] = ("R_PENALTY_AREA_TL_CORNER",
+            "R_PENALTY_AREA_TR_CORNER")
+        self.line_extremities_keys["Big rect. right main"] = ("R_PENALTY_AREA_TL_CORNER",
+            "R_PENALTY_AREA_BL_CORNER")
+
+        self.line_extremities_keys["Small rect. left bottom"] = ("L_GOAL_AREA_BL_CORNER",
+            "L_GOAL_AREA_BR_CORNER")
+        self.line_extremities_keys["Small rect. left top"] = ("L_GOAL_AREA_TL_CORNER",
+            "L_GOAL_AREA_TR_CORNER")
+        self.line_extremities_keys["Small rect. left main"] = ("L_GOAL_AREA_TR_CORNER",
+            "L_GOAL_AREA_BR_CORNER")
+        self.line_extremities_keys["Small rect. right bottom"] = ("R_GOAL_AREA_BL_CORNER",
+            "R_GOAL_AREA_BR_CORNER")
+        self.line_extremities_keys["Small rect. right top"] = ("R_GOAL_AREA_TL_CORNER",
+            "R_GOAL_AREA_TR_CORNER")
+        self.line_extremities_keys["Small rect. right main"] = ("R_GOAL_AREA_TL_CORNER",
+            "R_GOAL_AREA_BL_CORNER")
+
+        self.line_extremities_keys["Side line top"] = ("TL_PITCH_CORNER",
+            "TR_PITCH_CORNER")
+        self.line_extremities_keys["Side line bottom"] = ("BL_PITCH_CORNER",
+            "BR_PITCH_CORNER")
+        self.line_extremities_keys["Side line left"] = ("TL_PITCH_CORNER",
+            "BL_PITCH_CORNER")
+        self.line_extremities_keys["Side line right"] = ("TR_PITCH_CORNER",
+            "BR_PITCH_CORNER")
+        self.line_extremities_keys["Middle line"] = ("T_TOUCH_AND_HALFWAY_LINES_INTERSECTION",
+            "B_TOUCH_AND_HALFWAY_LINES_INTERSECTION")
+
+        self.line_extremities_keys["Goal left crossbar"] = ("L_GOAL_TR_POST",
+            "L_GOAL_TL_POST")
+        self.line_extremities_keys["Goal left post left "] = ("L_GOAL_TL_POST",
+            "L_GOAL_BL_POST")
+        self.line_extremities_keys["Goal left post right"] = ("L_GOAL_TR_POST",
+            "L_GOAL_BR_POST")
+
+        self.line_extremities_keys["Goal right crossbar"] = ("R_GOAL_TL_POST",
+            "R_GOAL_TR_POST")
+        self.line_extremities_keys["Goal right post left"] = ("R_GOAL_TL_POST",
+            "R_GOAL_BL_POST")
+        self.line_extremities_keys["Goal right post right"] = ("R_GOAL_TR_POST",
+            "R_GOAL_BR_POST")
+        self.line_extremities_keys["Circle right"] = ("TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
+            "BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION")
+        self.line_extremities_keys["Circle left"] = ("TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION",
+            "BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION")
+
+    def points(self):
+        return [
+            self.center_mark,
+            self.halfway_and_bottom_touch_line_mark,
+            self.halfway_and_top_touch_line_mark,
+            self.halfway_line_and_center_circle_top_mark,
+            self.halfway_line_and_center_circle_bottom_mark,
+            self.bottom_right_corner,
+            self.bottom_left_corner,
+            self.top_right_corner,
+            self.top_left_corner,
+            self.left_penalty_mark,
+            self.right_penalty_mark,
+            self.left_penalty_area_top_right_corner,
+            self.left_penalty_area_top_left_corner,
+            self.left_penalty_area_bottom_right_corner,
+            self.left_penalty_area_bottom_left_corner,
+            self.right_penalty_area_top_right_corner,
+            self.right_penalty_area_top_left_corner,
+            self.right_penalty_area_bottom_right_corner,
+            self.right_penalty_area_bottom_left_corner,
+            self.left_goal_area_top_right_corner,
+            self.left_goal_area_top_left_corner,
+            self.left_goal_area_bottom_right_corner,
+            self.left_goal_area_bottom_left_corner,
+            self.right_goal_area_top_right_corner,
+            self.right_goal_area_top_left_corner,
+            self.right_goal_area_bottom_right_corner,
+            self.right_goal_area_bottom_left_corner,
+            self.top_left_16M_penalty_arc_mark,
+            self.top_right_16M_penalty_arc_mark,
+            self.bottom_left_16M_penalty_arc_mark,
+            self.bottom_right_16M_penalty_arc_mark,
+            self.left_goal_top_left_post,
+            self.left_goal_top_right_post,
+            self.left_goal_bottom_left_post,
+            self.left_goal_bottom_right_post,
+            self.right_goal_top_left_post,
+            self.right_goal_top_right_post,
+            self.right_goal_bottom_left_post,
+            self.right_goal_bottom_right_post
+        ]
+
+    def sample_field_points(self, dist=0.1, dist_circles=0.2):
+        """
+        Samples each pitch element every dist meters and returns a dictionary associating the class of the element with a list of points sampled along this element.
+        :param dist: the distance in meters between each point sampled
+        :param dist_circles: the distance in meters between each point sampled on circles
+        :return: a dictionary associating the class of the element with a list of points sampled along this element.
+        """
+        polylines = dict()
+        center = self.point_dict["CENTER_MARK"]
+        fromAngle = 0.
+        toAngle = 2 * np.pi
+
+        if toAngle < fromAngle:
+            toAngle += 2 * np.pi
+        x1 = center[0] + np.cos(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+        y1 = center[1] + np.sin(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+        z1 = 0.
+        point = np.array((x1, y1, z1))
+        polyline = [point]
+        length = SoccerPitch.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
+        nb_pts = int(length / dist_circles)
+        dangle = dist_circles / SoccerPitch.CENTER_CIRCLE_RADIUS
+        for i in range(1, nb_pts):
+            angle = fromAngle + i * dangle
+            x = center[0] + np.cos(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+            y = center[1] + np.sin(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+            z = 0
+            point = np.array((x, y, z))
+            polyline.append(point)
+        polylines["Circle central"] = polyline
+        for key, line in self.line_extremities.items():
+
+            if "Circle" in key:
+                if key == "Circle right":
+                    top = self.point_dict["TR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+                    bottom = self.point_dict["BR_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+                    center = self.point_dict["R_PENALTY_MARK"]
+                    toAngle = np.arctan2(top[1] - center[1],
+                        top[0] - center[0]) + 2 * np.pi
+                    fromAngle = np.arctan2(bottom[1] - center[1],
+                        bottom[0] - center[0]) + 2 * np.pi
+                elif key == "Circle left":
+                    top = self.point_dict["TL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+                    bottom = self.point_dict["BL_16M_LINE_AND_PENALTY_ARC_INTERSECTION"]
+                    center = self.point_dict["L_PENALTY_MARK"]
+                    fromAngle = np.arctan2(top[1] - center[1],
+                        top[0] - center[0]) + 2 * np.pi
+                    toAngle = np.arctan2(bottom[1] - center[1],
+                        bottom[0] - center[0]) + 2 * np.pi
+                if toAngle < fromAngle:
+                    toAngle += 2 * np.pi
+                x1 = center[0] + np.cos(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+                y1 = center[1] + np.sin(fromAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+                z1 = 0.
+                xn = center[0] + np.cos(toAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+                yn = center[1] + np.sin(toAngle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+                zn = 0.
+                start = np.array((x1, y1, z1))
+                end = np.array((xn, yn, zn))
+                polyline = [start]
+                length = SoccerPitch.CENTER_CIRCLE_RADIUS * (toAngle - fromAngle)
+                nb_pts = int(length / dist_circles)
+                dangle = dist_circles / SoccerPitch.CENTER_CIRCLE_RADIUS
+                for i in range(1, nb_pts + 1):
+                    angle = fromAngle + i * dangle
+                    x = center[0] + np.cos(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+                    y = center[1] + np.sin(angle) * SoccerPitch.CENTER_CIRCLE_RADIUS
+                    z = 0
+                    point = np.array((x, y, z))
+                    polyline.append(point)
+                polyline.append(end)
+                polylines[key] = polyline
+            else:
+                start = line[0]
+                end = line[1]
+
+                polyline = [start]
+
+                total_dist = np.sqrt(np.sum(np.square(start - end)))
+                nb_pts = int(total_dist / dist - 1)
+
+                v = end - start
+                v /= np.linalg.norm(v)
+                prev_pt = start
+                for i in range(nb_pts):
+                    pt = prev_pt + dist * v
+                    prev_pt = pt
+                    polyline.append(pt)
+                polyline.append(end)
+                polylines[key] = polyline
+        return polylines
+
+    def get_2d_homogeneous_line(self, line_name):
+        """
+        For lines belonging to the pitch lawn plane, returns the 2D homogeneous equation coefficients of the line.
+        :param line_name: the name of the line
+        :return: an array containing the three coefficients of the line
+        """
+        # ensure line in football pitch plane
+        if line_name in self.line_extremities.keys() and \
+                "post" not in line_name and \
+                "crossbar" not in line_name and "Circle" not in line_name:
+            extremities = self.line_extremities[line_name]
+            p1 = np.array([extremities[0][0], extremities[0][1], 1], dtype="float")
+            p2 = np.array([extremities[1][0], extremities[1][1], 1], dtype="float")
+            line = np.cross(p1, p2)
+
+            return line
+        return None
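For orientation, here is a minimal usage sketch of the `SoccerPitch` class added above. It only uses the constructor, attributes and methods defined in this file (`point_dict`, `sample_field_points`, `get_2d_homogeneous_line`); the import path is an assumption (it supposes `sn_calibration/src` is on `PYTHONPATH`) and the chosen keys are just examples.

```python
# Minimal sketch, assuming sn_calibration/src is importable as a plain module directory.
from soccerpitch import SoccerPitch

pitch = SoccerPitch(pitch_length=105., pitch_width=68.)

# 3D pitch landmarks in metres, keyed by semantic name.
print(pitch.point_dict["L_PENALTY_MARK"])        # e.g. [-41.5, 0., 0.] for the default pitch

# Dense 3D sampling of every pitch element, one polyline per class name.
polylines = pitch.sample_field_points(dist=0.1, dist_circles=0.2)
print(len(polylines["Side line top"]))

# 2D homogeneous line coefficients, only for lines lying in the ground plane.
line = pitch.get_2d_homogeneous_line("Middle line")
if line is not None:
    a, b, c = line                               # a*x + b*y + c = 0
    print(a, b, c)
```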