Ethan18 committed
Commit 5c48b81 · verified · 1 Parent(s): 3b821c8

Upload folder using huggingface_hub
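That commit message is the default one written by huggingface_hub when a whole folder is pushed in a single commit. A minimal sketch of that kind of upload (the repo id and local path below are placeholders, not values taken from this commit):

    from huggingface_hub import HfApi

    api = HfApi()
    # Push every file in the local folder to the Space in one commit;
    # the default commit message is "Upload folder using huggingface_hub".
    api.upload_folder(
        folder_path="./LAM_H5",          # hypothetical local checkout
        repo_id="your-username/LAM_H5",  # hypothetical Space id
        repo_type="space",
    )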

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +10 -0
  3. LICENSE +201 -0
  4. README.md +4 -4
  5. app.py +626 -0
  6. app_lam.py +560 -0
  7. assets/images/logo.jpeg +0 -0
  8. assets/images/teaser.jpg +3 -0
  9. configs/inference/lam-20k-8gpu.yaml +130 -0
  10. configs/stylematte_config.json +2311 -0
  11. configs/vhap_tracking/base_tracking_config.yaml +64 -0
  12. external/human_matting/__init__.py +1 -0
  13. external/human_matting/matting_engine.py +66 -0
  14. external/human_matting/stylematte.py +272 -0
  15. external/landmark_detection/FaceBoxesV2/__init__.py +2 -0
  16. external/landmark_detection/FaceBoxesV2/detector.py +39 -0
  17. external/landmark_detection/FaceBoxesV2/faceboxes_detector.py +97 -0
  18. external/landmark_detection/FaceBoxesV2/utils/__init__.py +0 -0
  19. external/landmark_detection/FaceBoxesV2/utils/box_utils.py +276 -0
  20. external/landmark_detection/FaceBoxesV2/utils/build.py +57 -0
  21. external/landmark_detection/FaceBoxesV2/utils/config.py +14 -0
  22. external/landmark_detection/FaceBoxesV2/utils/faceboxes.py +239 -0
  23. external/landmark_detection/FaceBoxesV2/utils/make.sh +3 -0
  24. external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py +0 -0
  25. external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c +0 -0
  26. external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py +0 -0
  27. external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx +163 -0
  28. external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp +2 -0
  29. external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx +31 -0
  30. external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu +144 -0
  31. external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py +38 -0
  32. external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py +15 -0
  33. external/landmark_detection/FaceBoxesV2/utils/prior_box.py +43 -0
  34. external/landmark_detection/FaceBoxesV2/utils/timer.py +40 -0
  35. external/landmark_detection/README.md +110 -0
  36. external/landmark_detection/conf/__init__.py +1 -0
  37. external/landmark_detection/conf/alignment.py +239 -0
  38. external/landmark_detection/conf/base.py +94 -0
  39. external/landmark_detection/config.json +15 -0
  40. external/landmark_detection/data_processor/CheckFaceKeyPoint.py +147 -0
  41. external/landmark_detection/data_processor/align.py +193 -0
  42. external/landmark_detection/data_processor/process_pcd.py +250 -0
  43. external/landmark_detection/evaluate.py +258 -0
  44. external/landmark_detection/infer_folder.py +253 -0
  45. external/landmark_detection/infer_image.py +251 -0
  46. external/landmark_detection/infer_video.py +287 -0
  47. external/landmark_detection/lib/__init__.py +9 -0
  48. external/landmark_detection/lib/backbone/__init__.py +5 -0
  49. external/landmark_detection/lib/backbone/core/coord_conv.py +157 -0
  50. external/landmark_detection/lib/backbone/stackedHGNetV1.py +307 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/images/teaser.jpg filter=lfs diff=lfs merge=lfs -text
+ new_wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+ wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+ wheels/gradio_gaussian_render-0.0.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
+ wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+ wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__/
+ build/
+ *.so
+ assets/sample_input/
+ assets/sample_motion/
+ exps/
+ pretrain_model/
+ pretrained_models/
+ model_zoo/
+ tracking_output/
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,14 +1,14 @@
  ---
  title: LAM H5
  emoji: 🔥
- colorFrom: green
- colorTo: indigo
+ colorFrom: purple
+ colorTo: pink
  sdk: gradio
- sdk_version: 5.26.0
+ sdk_version: 5.23.3
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: LAM with H5 cross-platform rendering.
+ short_description: LAM with H5 rendering
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,626 @@
1
+ # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ os.system("rm -rf /data-nvme/zerogpu-offload/")
17
+ os.system("pip3 install chumpy")
18
+ os.system("pip3 install Cython")
19
+ os.system("pip3 install ./wheels/gradio_gaussian_render-0.0.1-py3-none-any.whl")
20
+ os.system("pip3 install ./new_wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl")
21
+ os.system("pip3 install ./wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl")
22
+ os.system("pip3 install ./wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl --force-reinstall")
23
+ os.system(
24
+ "pip3 install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
25
+ )
26
+ os.system("pip3 install numpy==1.23.0")
27
+
28
+ import cv2
29
+ import sys
30
+ import base64
31
+ import subprocess
32
+
33
+ import gradio as gr
34
+ import numpy as np
35
+ from PIL import Image
36
+ import argparse
37
+ from omegaconf import OmegaConf
38
+
39
+ import torch
40
+ import zipfile
41
+ from glob import glob
42
+ import moviepy.editor as mpy
43
+ from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image
44
+
45
+ import spaces
46
+
47
+
48
+ h5_rendering = True
49
+ if h5_rendering:
50
+ from gradio_gaussian_render import gaussian_render
51
+
52
+
53
+ def compile_module(subfolder, script):
54
+ try:
55
+ # Save the current working directory
56
+ current_dir = os.getcwd()
57
+ # Change directory to the subfolder
58
+ os.chdir(os.path.join(current_dir, subfolder))
59
+ # Run the compilation command
60
+ result = subprocess.run(
61
+ ["sh", script],
62
+ capture_output=True,
63
+ text=True,
64
+ check=True
65
+ )
66
+ # Print the compilation output
67
+ print("Compilation output:", result.stdout)
68
+
69
+ except Exception as e:
70
+ # Print any error that occurred
71
+ print(f"An error occurred: {e}")
72
+ finally:
73
+ # Ensure returning to the original directory
74
+ os.chdir(current_dir)
75
+ print("Returned to the original directory.")
76
+
77
+
78
+ # compile flame_tracking dependence submodule
79
+ compile_module("external/landmark_detection/FaceBoxesV2/utils/", "make.sh")
80
+ from tools.flame_tracking_single_image import FlameTrackingSingleImage
81
+
82
+
83
+ def launch_pretrained():
84
+ from huggingface_hub import snapshot_download, hf_hub_download
85
+ # launch pretrained for flame tracking.
86
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
87
+ repo_type='model',
88
+ filename='LAM_assets.tar',
89
+ local_dir='./')
90
+ os.system('tar -xf LAM_assets.tar && rm LAM_assets.tar')
91
+ # launch human model files
92
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
93
+ repo_type='model',
94
+ filename='thirdparty_models.tar',
95
+ local_dir='./')
96
+ os.system('tar -xf thirdparty_models.tar && rm thirdparty_models.tar')
97
+ # launch thirdparty applications
98
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
99
+ repo_type='model',
100
+ filename='thirdparties.tar',
101
+ local_dir='./tmp/')
102
+ os.system('tar -xf ./tmp/thirdparties.tar && rm -r tmp/')
103
+ # launch pretrained for LAM
104
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./model_zoo/lam_models/releases/lam/lam-20k/step_045500/", filename="config.json")
105
+ print(model_dir)
106
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./model_zoo/lam_models/releases/lam/lam-20k/step_045500/", filename="model.safetensors")
107
+ print(model_dir)
108
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./model_zoo/lam_models/releases/lam/lam-20k/step_045500/", filename="README.md")
109
+ print(model_dir)
110
+
111
+
112
+ def launch_env_not_compile_with_cuda():
113
+ os.system('pip install chumpy')
114
+ os.system('pip install numpy==1.23.0')
115
+ os.system(
116
+ 'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
117
+ )
118
+
119
+
120
+ def assert_input_image(input_image):
121
+ if input_image is None:
122
+ raise gr.Error('No image selected or uploaded!')
123
+
124
+
125
+ def prepare_working_dir():
126
+ import tempfile
127
+ working_dir = tempfile.TemporaryDirectory()
128
+ return working_dir
129
+
130
+
131
+ def init_preprocessor():
132
+ from lam.utils.preprocess import Preprocessor
133
+ global preprocessor
134
+ preprocessor = Preprocessor()
135
+
136
+
137
+ def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
138
+ working_dir):
139
+ image_raw = os.path.join(working_dir.name, 'raw.png')
140
+ with Image.fromarray(image_in) as img:
141
+ img.save(image_raw)
142
+ image_out = os.path.join(working_dir.name, 'rembg.png')
143
+ success = preprocessor.preprocess(image_path=image_raw,
144
+ save_path=image_out,
145
+ rmbg=remove_bg,
146
+ recenter=recenter)
147
+ assert success, f'Failed under preprocess_fn!'
148
+ return image_out
149
+
150
+
151
+ def get_image_base64(path):
152
+ with open(path, 'rb') as image_file:
153
+ encoded_string = base64.b64encode(image_file.read()).decode()
154
+ return f'data:image/png;base64,{encoded_string}'
155
+
156
+
157
+ def do_softlink(working_dir, tgt_dir="./runtime_data"):
158
+ os.system(f"rm {tgt_dir}")
159
+ cmd = f"ln -s {working_dir} ./runtime_data"
160
+ os.system(cmd)
161
+ return cmd
162
+
163
+
164
+ def doRender(working_dir):
165
+ working_dir = working_dir.name
166
+ cmd = do_softlink(working_dir)
167
+ print('='*100, "\n"+cmd, '\ndo render', "\n"+"="*100)
168
+ os.system("ls ./runtime_data")
169
+
170
+
171
+ def save_images2video(img_lst, v_pth, fps=30):
172
+ from moviepy.editor import ImageSequenceClip
173
+ # Ensure all images are in uint8 format
174
+ images = [image.astype(np.uint8) for image in img_lst]
175
+
176
+ # Create an ImageSequenceClip from the list of images
177
+ clip = ImageSequenceClip(images, fps=fps)
178
+
179
+ # Write the clip to a video file
180
+ clip.write_videofile(v_pth, codec='libx264')
181
+
182
+ print(f"Video saved successfully at {v_pth}")
183
+
184
+
185
+ def add_audio_to_video(video_path, out_path, audio_path, fps=30):
186
+ # Import necessary modules from moviepy
187
+ from moviepy.editor import VideoFileClip, AudioFileClip
188
+
189
+ # Load video file into VideoFileClip object
190
+ video_clip = VideoFileClip(video_path)
191
+
192
+ # Load audio file into AudioFileClip object
193
+ audio_clip = AudioFileClip(audio_path)
194
+
195
+ # Hard code clip audio
196
+ if audio_clip.duration > 10:
197
+ audio_clip = audio_clip.subclip(0, 10)
198
+
199
+ # Attach audio clip to video clip (replaces existing audio)
200
+ video_clip_with_audio = video_clip.set_audio(audio_clip)
201
+
202
+ # Export final video with audio using standard codecs
203
+ video_clip_with_audio.write_videofile(out_path, codec='libx264', audio_codec='aac', fps=fps)
204
+
205
+ print(f"Audio added successfully at {out_path}")
206
+
207
+
208
+ def parse_configs():
209
+
210
+ parser = argparse.ArgumentParser()
211
+ parser.add_argument("--config", type=str)
212
+ parser.add_argument("--infer", type=str)
213
+ args, unknown = parser.parse_known_args()
214
+
215
+ cfg = OmegaConf.create()
216
+ cli_cfg = OmegaConf.from_cli(unknown)
217
+
218
+ # parse from ENV
219
+ if os.environ.get("APP_INFER") is not None:
220
+ args.infer = os.environ.get("APP_INFER")
221
+ if os.environ.get("APP_MODEL_NAME") is not None:
222
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
223
+
224
+ args.config = args.infer if args.config is None else args.config
225
+
226
+ if args.config is not None:
227
+ cfg_train = OmegaConf.load(args.config)
228
+ cfg.source_size = cfg_train.dataset.source_image_res
229
+ try:
230
+ cfg.src_head_size = cfg_train.dataset.src_head_size
231
+ except:
232
+ cfg.src_head_size = 112
233
+ cfg.render_size = cfg_train.dataset.render_image.high
234
+ _relative_path = os.path.join(
235
+ cfg_train.experiment.parent,
236
+ cfg_train.experiment.child,
237
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
238
+ )
239
+
240
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
241
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
242
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
243
+
244
+ if args.infer is not None:
245
+ cfg_infer = OmegaConf.load(args.infer)
246
+ cfg.merge_with(cfg_infer)
247
+ cfg.setdefault(
248
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
249
+ )
250
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
251
+ cfg.setdefault(
252
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
253
+ )
254
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
255
+
256
+ cfg.motion_video_read_fps = 30
257
+ cfg.merge_with(cli_cfg)
258
+
259
+ cfg.setdefault("logger", "INFO")
260
+
261
+ assert cfg.model_name is not None, "model_name is required"
262
+
263
+ return cfg, cfg_train
264
+
265
+
266
+ def create_zip_archive(output_zip='runtime_data/h5_render_data.zip', base_vid="nice", in_fd="./runtime_data"):
267
+ flame_params_pth = os.path.join("./assets/sample_motion/export", base_vid, "flame_params.json")
268
+ file_lst = [
269
+ f'{in_fd}/lbs_weight_20k.json', f'{in_fd}/offset.ply', f'{in_fd}/skin.glb',
270
+ f'{in_fd}/vertex_order.json', f'{in_fd}/bone_tree.json',
271
+ flame_params_pth
272
+ ]
273
+ try:
274
+ # Create a new ZIP file in write mode
275
+ with zipfile.ZipFile(output_zip, 'w') as zipf:
276
+ # List all files in the specified directory
277
+ for file_path in file_lst:
278
+ zipf.write(file_path, arcname=os.path.join("h5_render_data", os.path.basename(file_path)))
279
+ print(f"Archive created successfully: {output_zip}")
280
+ except Exception as e:
281
+ print(f"An error occurred: {e}")
282
+
283
+
284
+ def demo_lam(flametracking, lam, cfg):
285
+
286
+ @spaces.GPU(duration=180)
287
+ def core_fn(image_path: str, video_params, working_dir):
288
+ image_raw = os.path.join(working_dir.name, "raw.png")
289
+ with Image.open(image_path).convert('RGB') as img:
290
+ img.save(image_raw)
291
+
292
+ base_vid = os.path.basename(video_params).split(".")[0]
293
+ flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
294
+ base_iid = os.path.basename(image_path).split('.')[0]
295
+ image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")
296
+
297
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
298
+ dump_image_path = os.path.join(working_dir.name, "output.png")
299
+
300
+ # prepare dump paths
301
+ omit_prefix = os.path.dirname(image_raw)
302
+ image_name = os.path.basename(image_raw)
303
+ uid = image_name.split(".")[0]
304
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
305
+ subdir_path = (
306
+ subdir_path[1:] if subdir_path.startswith("/") else subdir_path
307
+ )
308
+ print("subdir_path and uid:", subdir_path, uid)
309
+
310
+ motion_seqs_dir = flame_params_dir
311
+
312
+ dump_image_dir = os.path.dirname(dump_image_path)
313
+ os.makedirs(dump_image_dir, exist_ok=True)
314
+
315
+ print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
316
+
317
+ dump_tmp_dir = dump_image_dir
318
+
319
+
320
+ motion_img_need_mask = cfg.get("motion_img_need_mask", False) # False
321
+ vis_motion = cfg.get("vis_motion", False) # False
322
+
323
+ # preprocess input image: segmentation, flame params estimation
324
+ return_code = flametracking.preprocess(image_raw)
325
+ assert (return_code == 0), "flametracking preprocess failed!"
326
+ return_code = flametracking.optimize()
327
+ assert (return_code == 0), "flametracking optimize failed!"
328
+ return_code, output_dir = flametracking.export()
329
+ assert (return_code == 0), "flametracking export failed!"
330
+ image_path = os.path.join(output_dir, "images/00000_00.png")
331
+ mask_path = os.path.join(output_dir, "fg_masks/00000_00.png")
332
+ print("image_path:", image_path, "\n"+"mask_path:", mask_path)
333
+
334
+ aspect_standard = 1.0/1.0
335
+ source_size = cfg.source_size
336
+ render_size = cfg.render_size
337
+ render_fps = 30
338
+ # prepare reference image
339
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
340
+ max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
341
+ render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)
342
+
343
+ # save masked image for vis
344
+ save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
345
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
346
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
347
+
348
+ # prepare motion seq
349
+ src = image_path.split('/')[-3]
350
+ driven = motion_seqs_dir.split('/')[-2]
351
+ src_driven = [src, driven]
352
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
353
+ bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1,0],
354
+ render_image_res=render_size, multiply=16,
355
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
356
+ shape_param=shape_param, test_sample=False, cross_id=False, src_driven=src_driven,
357
+ max_squen_length=300)
358
+
359
+ # start inference
360
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
361
+ device, dtype = "cuda", torch.float32
362
+ print("start to inference...................")
363
+ with torch.no_grad():
364
+ # TODO check device and dtype
365
+ res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
366
+ render_c2ws=motion_seq["render_c2ws"].to(device),
367
+ render_intrs=motion_seq["render_intrs"].to(device),
368
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
369
+ flame_params={k:v.to(device) for k, v in motion_seq["flame_params"].items()})
370
+
371
+ # save h5 rendering info
372
+ if h5_rendering:
373
+ res['cano_gs_lst'][0].save_ply(os.path.join(working_dir.name, "offset.ply"), rgb2sh=False, offset2xyz=True)
374
+
375
+ h5_fd = working_dir.name
376
+ lam.renderer.flame_model.save_h5_info(shape_param.unsqueeze(0).cuda(), fd=h5_fd)
377
+ res['cano_gs_lst'][0].save_ply(os.path.join(h5_fd, "offset.ply"), rgb2sh=False, offset2xyz=True)
378
+ cmd = do_softlink(h5_fd)
379
+ cmd = "thirdparties/blender/blender --background --python 'tools/generateGLBWithBlender_v2.py'"
380
+ os.system(cmd)
381
+ output_zip = os.path.join(h5_fd, "h5_render_data.zip")
382
+ create_zip_archive(output_zip=output_zip, base_vid=base_vid, in_fd=h5_fd)
383
+
384
+
385
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
386
+ mask = res["comp_mask"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
387
+ mask[mask < 0.5] = 0.0
388
+ rgb = rgb * mask + (1 - mask) * 1
389
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
390
+ if vis_motion:
391
+ vis_ref_img = np.tile(
392
+ cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :],
393
+ (rgb.shape[0], 1, 1, 1),
394
+ )
395
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
396
+
397
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
398
+
399
+ save_images2video(rgb, dump_video_path, render_fps)
400
+ audio_path = os.path.join("./assets/sample_motion/export", base_vid, base_vid+".wav")
401
+ dump_video_path_wa = dump_video_path.replace(".mp4", "_audio.mp4")
402
+ add_audio_to_video(dump_video_path, dump_video_path_wa, audio_path)
403
+
404
+ return dump_image_path, dump_video_path_wa
405
+
406
+ with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600]) as demo:
407
+
408
+ logo_url = './assets/images/logo.jpeg'
409
+ logo_base64 = get_image_base64(logo_url)
410
+ gr.HTML(f"""
411
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
412
+ <div>
413
+ <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Avatar Model for One-shot Animatable Gaussian Head</h1>
414
+ </div>
415
+ </div>
416
+ """)
417
+ gr.HTML(
418
+ """
419
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
420
+ <a class="flex-item" href="https://arxiv.org/abs/2502.17796" target="_blank">
421
+ <img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper">
422
+ </a>
423
+ <a class="flex-item" href="https://aigc3d.github.io/projects/LAM/" target="_blank">
424
+ <img src="https://img.shields.io/badge/Project-LAM-blue" alt="Project Page">
425
+ </a>
426
+ <a class="flex-item" href="https://github.com/aigc3d/LAM" target="_blank">
427
+ <img src="https://img.shields.io/github/stars/aigc3d/LAM?label=Github%20★&logo=github&color=C8C" alt="badge-github-stars">
428
+ </a>
429
+ <a class="flex-item" href="https://youtu.be/FrfE3RYSKhk" target="_blank">
430
+ <img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video">
431
+ </a>
432
+ </div>
433
+ """
434
+ )
435
+
436
+ gr.HTML("""<div style="margin-top: -10px">
437
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: black; margin: 2px 0">Notes1: Inputing front-face images or face orientation close to the driven signal gets better results.</h4></p>
438
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: black; margin: 2px 0">Notes2: Due to computational constraints with Hugging Face's ZeroGPU infrastructure, 3D avatar generation requires ~1 minute per instance.</h4></p>
439
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: black; margin: 2px 0">Notes3: Using LAM-20K model (lower quality than premium LAM-80K) to mitigate processing latency.</h4></p>
440
+ </div>""")
441
+
442
+ # DISPLAY
443
+ with gr.Row():
444
+
445
+ with gr.Column(variant='panel', scale=1):
446
+ with gr.Tabs(elem_id='lam_input_image'):
447
+ with gr.TabItem('Input Image'):
448
+ with gr.Row():
449
+ input_image = gr.Image(label='Input Image',
450
+ image_mode='RGB',
451
+ height=480,
452
+ width=270,
453
+ sources='upload',
454
+ type='filepath',
455
+ elem_id='content_image')
456
+ # EXAMPLES
457
+ with gr.Row():
458
+ examples = [
459
+ ['assets/sample_input/messi.png'],
460
+ ['assets/sample_input/status.png'],
461
+ ['assets/sample_input/james.png'],
462
+ ['assets/sample_input/cluo.jpg'],
463
+ ['assets/sample_input/dufu.jpg'],
464
+ ['assets/sample_input/libai.jpg'],
465
+ ['assets/sample_input/barbara.jpg'],
466
+ ['assets/sample_input/pop.png'],
467
+ ['assets/sample_input/musk.jpg'],
468
+ ['assets/sample_input/speed.jpg'],
469
+ ['assets/sample_input/zhouxingchi.jpg'],
470
+ ]
471
+ gr.Examples(
472
+ examples=examples,
473
+ inputs=[input_image],
474
+ examples_per_page=20
475
+ )
476
+
477
+ with gr.Column():
478
+ with gr.Tabs(elem_id='lam_input_video'):
479
+ with gr.TabItem('Input Video'):
480
+ with gr.Row():
481
+ video_input = gr.Video(label='Input Video',
482
+ height=480,
483
+ width=270,
484
+ interactive=False)
485
+
486
+ examples = ['./assets/sample_motion/export/Speeding_Scandal/Speeding_Scandal.mp4',
487
+ './assets/sample_motion/export/Look_In_My_Eyes/Look_In_My_Eyes.mp4',
488
+ './assets/sample_motion/export/D_ANgelo_Dinero/D_ANgelo_Dinero.mp4',
489
+ './assets/sample_motion/export/Michael_Wayne_Rosen/Michael_Wayne_Rosen.mp4',
490
+ './assets/sample_motion/export/I_Am_Iron_Man/I_Am_Iron_Man.mp4',
491
+ './assets/sample_motion/export/Anti_Drugs/Anti_Drugs.mp4',
492
+ './assets/sample_motion/export/Pen_Pineapple_Apple_Pen/Pen_Pineapple_Apple_Pen.mp4',
493
+ './assets/sample_motion/export/Joe_Biden/Joe_Biden.mp4',
494
+ './assets/sample_motion/export/Donald_Trump/Donald_Trump.mp4',
495
+ './assets/sample_motion/export/Taylor_Swift/Taylor_Swift.mp4',
496
+ './assets/sample_motion/export/GEM/GEM.mp4',
497
+ './assets/sample_motion/export/The_Shawshank_Redemption/The_Shawshank_Redemption.mp4'
498
+ ]
499
+ print("Video example list {}".format(examples))
500
+
501
+ gr.Examples(
502
+ examples=examples,
503
+ inputs=[video_input],
504
+ examples_per_page=20,
505
+ )
506
+ with gr.Column(variant='panel', scale=1):
507
+ with gr.Tabs(elem_id='lam_processed_image'):
508
+ with gr.TabItem('Processed Image'):
509
+ with gr.Row():
510
+ processed_image = gr.Image(
511
+ label='Processed Image',
512
+ image_mode='RGBA',
513
+ type='filepath',
514
+ elem_id='processed_image',
515
+ height=480,
516
+ width=270,
517
+ interactive=False)
518
+
519
+ with gr.Column(variant='panel', scale=1):
520
+ with gr.Tabs(elem_id='lam_render_video'):
521
+ with gr.TabItem('Rendered Video'):
522
+ with gr.Row():
523
+ output_video = gr.Video(label='Rendered Video',
524
+ format='mp4',
525
+ height=480,
526
+ width=270,
527
+ autoplay=True)
528
+
529
+ # SETTING
530
+ with gr.Row():
531
+ with gr.Column(variant='panel', scale=1):
532
+ submit = gr.Button('Generate',
533
+ elem_id='lam_generate',
534
+ variant='primary')
535
+
536
+ if h5_rendering:
537
+ gr.HTML(f"""
538
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
539
+ <div>
540
+ <h2> Cross-platform H5 Rendering</h2>
541
+ </div>
542
+ </div>
543
+ """)
544
+ gr.set_static_paths("runtime_data/")
545
+ assetPrefix = 'gradio_api/file=runtime_data/'
546
+ with gr.Row():
547
+ gs = gaussian_render(width = 300, height = 400, assets = assetPrefix + 'h5_render_data.zip')
548
+
549
+ working_dir = gr.State()
550
+ submit.click(
551
+ fn=assert_input_image,
552
+ inputs=[input_image],
553
+ queue=False,
554
+ ).success(
555
+ fn=prepare_working_dir,
556
+ outputs=[working_dir],
557
+ queue=False,
558
+ ).success(
559
+ fn=core_fn,
560
+ inputs=[input_image, video_input,
561
+ working_dir],
562
+ outputs=[processed_image, output_video],
563
+ ).success(
564
+ doRender,
565
+ inputs=[working_dir],
566
+ js='''() => window.start()'''
567
+ )
568
+
569
+ demo.queue()
570
+ demo.launch()
571
+
572
+
573
+ def _build_model(cfg):
574
+ from lam.models import ModelLAM
575
+ from safetensors.torch import load_file
576
+
577
+ model = ModelLAM(**cfg.model)
578
+ resume = os.path.join(cfg.model_name, "model.safetensors")
579
+ print("="*100)
580
+ print("loading pretrained weight from:", resume)
581
+ if resume.endswith('safetensors'):
582
+ ckpt = load_file(resume, device='cpu')
583
+ else:
584
+ ckpt = torch.load(resume, map_location='cpu')
585
+ state_dict = model.state_dict()
586
+ for k, v in ckpt.items():
587
+ if k in state_dict:
588
+ if state_dict[k].shape == v.shape:
589
+ state_dict[k].copy_(v)
590
+ else:
591
+ print(f"WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.")
592
+ else:
593
+ print(f"WARN] unexpected param {k}: {v.shape}")
594
+ print("finish loading pretrained weight from:", resume)
595
+ print("="*100)
596
+ return model
597
+
598
+
599
+ def launch_gradio_app():
600
+
601
+ os.environ.update({
602
+ 'APP_ENABLED': '1',
603
+ 'APP_MODEL_NAME':
604
+ './model_zoo/lam_models/releases/lam/lam-20k/step_045500/',
605
+ 'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
606
+ 'APP_TYPE': 'infer.lam',
607
+ 'NUMBA_THREADING_LAYER': 'omp',
608
+ })
609
+
610
+ cfg, _ = parse_configs()
611
+ lam = _build_model(cfg)
612
+ lam.to('cuda')
613
+
614
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
615
+ alignment_model_path='./model_zoo/flame_tracking_models/68_keypoints_model.pkl',
616
+ vgghead_model_path='./model_zoo/flame_tracking_models/vgghead/vgg_heads_l.trcd',
617
+ human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt',
618
+ facebox_model_path='./model_zoo/flame_tracking_models/FaceBoxesV2.pth',
619
+ detect_iris_landmarks=False)
620
+
621
+ demo_lam(flametracking, lam, cfg)
622
+
623
+
624
+ if __name__ == '__main__':
625
+ launch_pretrained()
626
+ launch_gradio_app()
app_lam.py ADDED
@@ -0,0 +1,560 @@
1
+ # Copyright (c) 2024-2025, The Alibaba 3DAIGC Team Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import cv2
17
+ import sys
18
+ import base64
19
+ import subprocess
20
+
21
+ import gradio as gr
22
+ import numpy as np
23
+ from PIL import Image
24
+ import argparse
25
+ from omegaconf import OmegaConf
26
+
27
+ import torch
28
+ import zipfile
29
+ from glob import glob
30
+ import moviepy.editor as mpy
31
+ from tools.flame_tracking_single_image import FlameTrackingSingleImage
32
+ from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image
33
+
34
+ try:
35
+ import spaces
36
+ except:
37
+ pass
38
+
39
+
40
+ h5_rendering = True
41
+ from gradio_gaussian_render import gaussian_render
42
+
43
+
44
+ def launch_env_not_compile_with_cuda():
45
+ os.system('pip install chumpy')
46
+ os.system('pip install numpy==1.23.0')
47
+ os.system(
48
+ 'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
49
+ )
50
+
51
+
52
+ def assert_input_image(input_image):
53
+ if input_image is None:
54
+ raise gr.Error('No image selected or uploaded!')
55
+
56
+
57
+ def prepare_working_dir():
58
+ import tempfile
59
+ working_dir = tempfile.TemporaryDirectory()
60
+ return working_dir
61
+
62
+
63
+ def init_preprocessor():
64
+ from lam.utils.preprocess import Preprocessor
65
+ global preprocessor
66
+ preprocessor = Preprocessor()
67
+
68
+
69
+ def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
70
+ working_dir):
71
+ image_raw = os.path.join(working_dir.name, 'raw.png')
72
+ with Image.fromarray(image_in) as img:
73
+ img.save(image_raw)
74
+ image_out = os.path.join(working_dir.name, 'rembg.png')
75
+ success = preprocessor.preprocess(image_path=image_raw,
76
+ save_path=image_out,
77
+ rmbg=remove_bg,
78
+ recenter=recenter)
79
+ assert success, f'Failed under preprocess_fn!'
80
+ return image_out
81
+
82
+
83
+ def get_image_base64(path):
84
+ with open(path, 'rb') as image_file:
85
+ encoded_string = base64.b64encode(image_file.read()).decode()
86
+ return f'data:image/png;base64,{encoded_string}'
87
+
88
+
89
+ def do_softlink(working_dir, tgt_dir="./runtime_data"):
90
+ os.system(f"rm {tgt_dir}")
91
+ cmd = f"ln -s {working_dir} ./runtime_data"
92
+ os.system(cmd)
93
+ return cmd
94
+
95
+
96
+ def doRender(working_dir):
97
+ working_dir = working_dir.name
98
+ cmd = do_softlink(working_dir)
99
+ print('='*100, "\n"+cmd, '\ndo render', "\n"+"="*100)
100
+
101
+
102
+ def save_images2video(img_lst, v_pth, fps):
103
+ from moviepy.editor import ImageSequenceClip
104
+ # Ensure all images are in uint8 format
105
+ images = [image.astype(np.uint8) for image in img_lst]
106
+
107
+ # Create an ImageSequenceClip from the list of images
108
+ clip = ImageSequenceClip(images, fps=fps)
109
+
110
+ # Write the clip to a video file
111
+ clip.write_videofile(v_pth, codec='libx264')
112
+
113
+ print(f"Video saved successfully at {v_pth}")
114
+
115
+
116
+ def add_audio_to_video(video_path, out_path, audio_path, fps=30):
117
+ # Import necessary modules from moviepy
118
+ from moviepy.editor import VideoFileClip, AudioFileClip
119
+
120
+ # Load video file into VideoFileClip object
121
+ video_clip = VideoFileClip(video_path)
122
+
123
+ # Load audio file into AudioFileClip object
124
+ audio_clip = AudioFileClip(audio_path)
125
+
126
+ # Hard code clip audio
127
+ """
128
+ if audio_clip.duration > 10:
129
+ audio_clip = audio_clip.subclip(0, 10)
130
+ """
131
+
132
+ # Attach audio clip to video clip (replaces existing audio)
133
+ video_clip_with_audio = video_clip.set_audio(audio_clip)
134
+
135
+ # Export final video with audio using standard codecs
136
+ video_clip_with_audio.write_videofile(out_path, codec='libx264', audio_codec='aac', fps=fps)
137
+
138
+ print(f"Audio added successfully at {out_path}")
139
+
140
+ def parse_configs():
141
+
142
+ parser = argparse.ArgumentParser()
143
+ parser.add_argument("--config", type=str)
144
+ parser.add_argument("--infer", type=str)
145
+ args, unknown = parser.parse_known_args()
146
+
147
+ cfg = OmegaConf.create()
148
+ cli_cfg = OmegaConf.from_cli(unknown)
149
+
150
+ # parse from ENV
151
+ if os.environ.get("APP_INFER") is not None:
152
+ args.infer = os.environ.get("APP_INFER")
153
+ if os.environ.get("APP_MODEL_NAME") is not None:
154
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
155
+
156
+ args.config = args.infer if args.config is None else args.config
157
+
158
+ if args.config is not None:
159
+ cfg_train = OmegaConf.load(args.config)
160
+ cfg.source_size = cfg_train.dataset.source_image_res
161
+ try:
162
+ cfg.src_head_size = cfg_train.dataset.src_head_size
163
+ except:
164
+ cfg.src_head_size = 112
165
+ cfg.render_size = cfg_train.dataset.render_image.high
166
+ _relative_path = os.path.join(
167
+ cfg_train.experiment.parent,
168
+ cfg_train.experiment.child,
169
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
170
+ )
171
+
172
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
173
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
174
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
175
+
176
+ if args.infer is not None:
177
+ cfg_infer = OmegaConf.load(args.infer)
178
+ cfg.merge_with(cfg_infer)
179
+ cfg.setdefault(
180
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
181
+ )
182
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
183
+ cfg.setdefault(
184
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
185
+ )
186
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
187
+
188
+ cfg.motion_video_read_fps = 30
189
+ cfg.merge_with(cli_cfg)
190
+
191
+ cfg.setdefault("logger", "INFO")
192
+
193
+ assert cfg.model_name is not None, "model_name is required"
194
+
195
+ return cfg, cfg_train
196
+
197
+
198
+ def create_zip_archive(output_zip='runtime_data/h5_render_data.zip', base_vid="nice", in_fd="./runtime_data"):
199
+ flame_params_pth = os.path.join("./assets/sample_motion/export", base_vid, "flame_params.json")
200
+ file_lst = [
201
+ f'{in_fd}/lbs_weight_20k.json', f'{in_fd}/offset.ply', f'{in_fd}/skin.glb',
202
+ f'{in_fd}/vertex_order.json', f'{in_fd}/bone_tree.json',
203
+ flame_params_pth
204
+ ]
205
+ try:
206
+ # Create a new ZIP file in write mode
207
+ with zipfile.ZipFile(output_zip, 'w') as zipf:
208
+ # List all files in the specified directory
209
+ for file_path in file_lst:
210
+ zipf.write(file_path, arcname=os.path.join("h5_render_data", os.path.basename(file_path)))
211
+ print(f"Archive created successfully: {output_zip}")
212
+ except Exception as e:
213
+ print(f"An error occurred: {e}")
214
+
215
+
216
+ def demo_lam(flametracking, lam, cfg):
217
+
218
+ # @spaces.GPU(duration=80)
219
+ def core_fn(image_path: str, video_params, working_dir):
220
+ image_raw = os.path.join(working_dir.name, "raw.png")
221
+ with Image.open(image_path).convert('RGB') as img:
222
+ img.save(image_raw)
223
+
224
+ base_vid = os.path.basename(video_params).split(".")[0]
225
+ flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
226
+ base_iid = os.path.basename(image_path).split('.')[0]
227
+ image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")
228
+
229
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
230
+ dump_image_path = os.path.join(working_dir.name, "output.png")
231
+
232
+ # prepare dump paths
233
+ omit_prefix = os.path.dirname(image_raw)
234
+ image_name = os.path.basename(image_raw)
235
+ uid = image_name.split(".")[0]
236
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
237
+ subdir_path = (
238
+ subdir_path[1:] if subdir_path.startswith("/") else subdir_path
239
+ )
240
+ print("subdir_path and uid:", subdir_path, uid)
241
+
242
+ motion_seqs_dir = flame_params_dir
243
+
244
+ dump_image_dir = os.path.dirname(dump_image_path)
245
+ os.makedirs(dump_image_dir, exist_ok=True)
246
+
247
+ print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
248
+
249
+ dump_tmp_dir = dump_image_dir
250
+
251
+ if os.path.exists(dump_video_path):
252
+ return dump_image_path, dump_video_path
253
+
254
+ motion_img_need_mask = cfg.get("motion_img_need_mask", False) # False
255
+ vis_motion = cfg.get("vis_motion", False) # False
256
+
257
+ # preprocess input image: segmentation, flame params estimation
258
+ return_code = flametracking.preprocess(image_raw)
259
+ assert (return_code == 0), "flametracking preprocess failed!"
260
+ return_code = flametracking.optimize()
261
+ assert (return_code == 0), "flametracking optimize failed!"
262
+ return_code, output_dir = flametracking.export()
263
+ assert (return_code == 0), "flametracking export failed!"
264
+
265
+ image_path = os.path.join(output_dir, "images/00000_00.png")
266
+ mask_path = os.path.join(output_dir, "fg_masks/00000_00.png")
267
+ print("image_path:", image_path, "\n"+"mask_path:", mask_path)
268
+
269
+ aspect_standard = 1.0/1.0
270
+ source_size = cfg.source_size
271
+ render_size = cfg.render_size
272
+ render_fps = 30
273
+ # prepare reference image
274
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
275
+ max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
276
+ render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)
277
+
278
+ # save masked image for vis
279
+ save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
280
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
281
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
282
+
283
+ # prepare motion seq
284
+ src = image_path.split('/')[-3]
285
+ driven = motion_seqs_dir.split('/')[-2]
286
+ src_driven = [src, driven]
287
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
288
+ bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1,0],
289
+ render_image_res=render_size, multiply=16,
290
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
291
+ shape_param=shape_param, test_sample=False, cross_id=False, src_driven=src_driven)
292
+
293
+ # start inference
294
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
295
+ device, dtype = "cuda", torch.float32
296
+ print("start to inference...................")
297
+ with torch.no_grad():
298
+ # TODO check device and dtype
299
+ res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
300
+ render_c2ws=motion_seq["render_c2ws"].to(device),
301
+ render_intrs=motion_seq["render_intrs"].to(device),
302
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
303
+ flame_params={k:v.to(device) for k, v in motion_seq["flame_params"].items()})
304
+
305
+ # save h5 rendering info
306
+ if h5_rendering:
307
+ res['cano_gs_lst'][0].save_ply(os.path.join(working_dir.name, "offset.ply"), rgb2sh=False, offset2xyz=True)
308
+
309
+ h5_fd = working_dir.name
310
+ lam.renderer.flame_model.save_h5_info(shape_param.unsqueeze(0).cuda(), fd=h5_fd)
311
+ res['cano_gs_lst'][0].save_ply(os.path.join(h5_fd, "offset.ply"), rgb2sh=False, offset2xyz=True)
312
+ cmd = do_softlink(h5_fd)
313
+ cmd = "thirdparties/blender/blender --background --python 'tools/generateGLBWithBlender_v2.py'"
314
+ os.system(cmd)
315
+ output_zip = os.path.join(h5_fd, "h5_render_data.zip")
316
+ create_zip_archive(output_zip=output_zip, base_vid=base_vid, in_fd=h5_fd)
317
+
318
+
319
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
320
+ mask = res["comp_mask"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
321
+ mask[mask < 0.5] = 0.0
322
+ rgb = rgb * mask + (1 - mask) * 1
323
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
324
+ if vis_motion:
325
+ vis_ref_img = np.tile(
326
+ cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :],
327
+ (rgb.shape[0], 1, 1, 1),
328
+ )
329
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
330
+
331
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
332
+
333
+ save_images2video(rgb, dump_video_path, render_fps)
334
+ audio_path = os.path.join("./assets/sample_motion/export", base_vid, base_vid+".wav")
335
+ dump_video_path_wa = dump_video_path.replace(".mp4", "_audio.mp4")
336
+ add_audio_to_video(dump_video_path, dump_video_path_wa, audio_path)
337
+
338
+ return dump_image_path, dump_video_path_wa
339
+
340
+ with gr.Blocks(analytics_enabled=False) as demo:
341
+
342
+ logo_url = './assets/images/logo.jpeg'
343
+ logo_base64 = get_image_base64(logo_url)
344
+ gr.HTML(f"""
345
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
346
+ <div>
347
+ <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Avatar Model for One-shot Animatable Gaussian Head</h1>
348
+ </div>
349
+ </div>
350
+ """)
351
+ gr.HTML(
352
+ """
353
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
354
+ <a class="flex-item" href="https://arxiv.org/abs/2502.17796" target="_blank">
355
+ <img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper">
356
+ </a>
357
+ <a class="flex-item" href="https://aigc3d.github.io/projects/LAM/" target="_blank">
358
+ <img src="https://img.shields.io/badge/Project-LAM-blue" alt="Project Page">
359
+ </a>
360
+ <a class="flex-item" href="https://github.com/aigc3d/LAM" target="_blank">
361
+ <img src="https://img.shields.io/github/stars/aigc3d/LAM?label=Github%20★&logo=github&color=C8C" alt="badge-github-stars">
362
+ </a>
363
+ <a class="flex-item" href="https://youtu.be/FrfE3RYSKhk" target="_blank">
364
+ <img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video">
365
+ </a>
366
+ </div>
367
+ """
368
+ )
369
+
370
+ gr.HTML("""<div style="margin-top: -10px">
371
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Notes1: Inputing front-face images or face orientation close to the driven signal gets better results.</h4></p>
372
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Notes2: Due to computational constraints with Hugging Face's ZeroGPU infrastructure, video generation requires ~1 minute per instance.</h4></p>
373
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Notes3: Using LAM-20K model (lower quality than premium LAM-80K) to mitigate processing latency.</h4></p>
374
+ </div>""")
375
+
376
+ # DISPLAY
377
+ with gr.Row():
378
+
379
+ with gr.Column(variant='panel', scale=1):
380
+ with gr.Tabs(elem_id='lam_input_image'):
381
+ with gr.TabItem('Input Image'):
382
+ with gr.Row():
383
+ input_image = gr.Image(label='Input Image',
384
+ image_mode='RGB',
385
+ height=480,
386
+ width=270,
387
+ sources='upload',
388
+ type='filepath',
389
+ elem_id='content_image')
390
+ # EXAMPLES
391
+ with gr.Row():
392
+ examples = [
393
+ ['assets/sample_input/messi.png'],
394
+ ['assets/sample_input/status.png'],
395
+ ['assets/sample_input/james.png'],
396
+ ['assets/sample_input/cluo.jpg'],
397
+ ['assets/sample_input/dufu.jpg'],
398
+ ['assets/sample_input/libai.jpg'],
399
+ ['assets/sample_input/barbara.jpg'],
400
+ ['assets/sample_input/pop.png'],
401
+ ['assets/sample_input/musk.jpg'],
402
+ ['assets/sample_input/speed.jpg'],
403
+ ['assets/sample_input/zhouxingchi.jpg'],
404
+ ]
405
+ gr.Examples(
406
+ examples=examples,
407
+ inputs=[input_image],
408
+ examples_per_page=20
409
+ )
410
+
411
+ with gr.Column():
412
+ with gr.Tabs(elem_id='lam_input_video'):
413
+ with gr.TabItem('Input Video'):
414
+ with gr.Row():
415
+ video_input = gr.Video(label='Input Video',
416
+ height=480,
417
+ width=270,
418
+ interactive=False)
419
+
420
+ examples = ['./assets/sample_motion/export/Speeding_Scandal/Speeding_Scandal.mp4',
421
+ './assets/sample_motion/export/Look_In_My_Eyes/Look_In_My_Eyes.mp4',
422
+ './assets/sample_motion/export/D_ANgelo_Dinero/D_ANgelo_Dinero.mp4',
423
+ './assets/sample_motion/export/Michael_Wayne_Rosen/Michael_Wayne_Rosen.mp4',
424
+ './assets/sample_motion/export/I_Am_Iron_Man/I_Am_Iron_Man.mp4',
425
+ './assets/sample_motion/export/Anti_Drugs/Anti_Drugs.mp4',
426
+ './assets/sample_motion/export/Pen_Pineapple_Apple_Pen/Pen_Pineapple_Apple_Pen.mp4',
427
+ './assets/sample_motion/export/Joe_Biden/Joe_Biden.mp4',
428
+ './assets/sample_motion/export/Donald_Trump/Donald_Trump.mp4',
429
+ './assets/sample_motion/export/Taylor_Swift/Taylor_Swift.mp4',
430
+ './assets/sample_motion/export/GEM/GEM.mp4',
431
+ './assets/sample_motion/export/The_Shawshank_Redemption/The_Shawshank_Redemption.mp4'
432
+ ]
433
+ print("Video example list {}".format(examples))
434
+
435
+ gr.Examples(
436
+ examples=examples,
437
+ inputs=[video_input],
438
+ examples_per_page=20,
439
+ )
440
+ with gr.Column(variant='panel', scale=1):
441
+ with gr.Tabs(elem_id='lam_processed_image'):
442
+ with gr.TabItem('Processed Image'):
443
+ with gr.Row():
444
+ processed_image = gr.Image(
445
+ label='Processed Image',
446
+ image_mode='RGBA',
447
+ type='filepath',
448
+ elem_id='processed_image',
449
+ height=480,
450
+ width=270,
451
+ interactive=False)
452
+
453
+ with gr.Column(variant='panel', scale=1):
454
+ with gr.Tabs(elem_id='lam_render_video'):
455
+ with gr.TabItem('Rendered Video'):
456
+ with gr.Row():
457
+ output_video = gr.Video(label='Rendered Video',
458
+ format='mp4',
459
+ height=480,
460
+ width=270,
461
+ autoplay=True)
462
+
463
+ # SETTING
464
+ with gr.Row():
465
+ with gr.Column(variant='panel', scale=1):
466
+ submit = gr.Button('Generate',
467
+ elem_id='lam_generate',
468
+ variant='primary')
469
+
470
+ if h5_rendering:
471
+ gr.HTML(f"""
472
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
473
+ <div>
474
+ <h2> Cross-platform H5 Rendering</h2>
475
+ </div>
476
+ </div>
477
+ """)
478
+ gr.set_static_paths("runtime_data/")
479
+ assetPrefix = 'gradio_api/file=runtime_data/'
480
+ with gr.Row():
481
+ gs = gaussian_render(width=300, height=400, assets=assetPrefix + 'h5_render_data.zip')
482
+
483
+ working_dir = gr.State()
484
+ submit.click(
485
+ fn=assert_input_image,
486
+ inputs=[input_image],
487
+ queue=False,
488
+ ).success(
489
+ fn=prepare_working_dir,
490
+ outputs=[working_dir],
491
+ queue=False,
492
+ ).success(
493
+ fn=core_fn,
494
+ inputs=[input_image, video_input,
495
+ working_dir],
496
+ outputs=[processed_image, output_video],
497
+ ).success(
498
+ doRender,
499
+ inputs=[working_dir],
500
+ js='''() => window.start()'''
501
+ )
502
+
503
+ demo.queue()
504
+ demo.launch()
505
+
506
+
507
+ def _build_model(cfg):
508
+ from lam.models import ModelLAM
509
+ from safetensors.torch import load_file
510
+
511
+ model = ModelLAM(**cfg.model)
512
+ resume = os.path.join(cfg.model_name, "model.safetensors")
513
+ print("="*100)
514
+ print("loading pretrained weight from:", resume)
515
+ if resume.endswith('safetensors'):
516
+ ckpt = load_file(resume, device='cpu')
517
+ else:
518
+ ckpt = torch.load(resume, map_location='cpu')
519
+ state_dict = model.state_dict()
520
+ for k, v in ckpt.items():
521
+ if k in state_dict:
522
+ if state_dict[k].shape == v.shape:
523
+ state_dict[k].copy_(v)
524
+ else:
525
+ print(f"WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.")
526
+ else:
527
+ print(f"WARN] unexpected param {k}: {v.shape}")
528
+ print("finish loading pretrained weight from:", resume)
529
+ print("="*100)
530
+ return model
531
+
532
+
533
+ def launch_gradio_app():
534
+
535
+ os.environ.update({
536
+ 'APP_ENABLED': '1',
537
+ 'APP_MODEL_NAME':
538
+ './model_zoo/lam_models/releases/lam/lam-20k/step_045500/',
539
+ 'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
540
+ 'APP_TYPE': 'infer.lam',
541
+ 'NUMBA_THREADING_LAYER': 'omp',
542
+ })
543
+
544
+ cfg, _ = parse_configs()
545
+ lam = _build_model(cfg)
546
+ lam.to('cuda')
547
+
548
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
549
+ alignment_model_path='./model_zoo/flame_tracking_models/68_keypoints_model.pkl',
550
+ vgghead_model_path='./model_zoo/flame_tracking_models/vgghead/vgg_heads_l.trcd',
551
+ human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt',
552
+ facebox_model_path='./model_zoo/flame_tracking_models/FaceBoxesV2.pth',
553
+ detect_iris_landmarks=True)
554
+
555
+ demo_lam(flametracking, lam, cfg)
556
+
557
+
558
+ if __name__ == '__main__':
559
+ # launch_env_not_compile_with_cuda()
560
+ launch_gradio_app()
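
For readers less familiar with Gradio's event chaining, the `submit.click(...).success(...)` wiring above runs each callback only if the previous one finished without raising. Below is a minimal, self-contained sketch of the same pattern; the callback names (check_inputs, make_workdir, run_pipeline) are hypothetical stand-ins for assert_input_image, prepare_working_dir and core_fn, not the project's actual functions.

import gradio as gr

# Hypothetical stand-ins for the real callbacks used in app_lam.py above.
def check_inputs(image_path):
    # raising aborts the chain, so later steps never run
    if image_path is None:
        raise gr.Error("Please upload an image first.")

def make_workdir():
    return "/tmp/demo_workdir"

def run_pipeline(image_path, workdir):
    return f"processed {image_path} in {workdir}"

with gr.Blocks() as demo:
    image = gr.Image(type="filepath")
    workdir = gr.State()
    out = gr.Textbox()
    btn = gr.Button("Generate")
    # each .success() step runs only if the previous callback completed without raising
    btn.click(fn=check_inputs, inputs=[image], queue=False).success(
        fn=make_workdir, outputs=[workdir], queue=False
    ).success(
        fn=run_pipeline, inputs=[image, workdir], outputs=[out]
    )

if __name__ == "__main__":
    demo.launch()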
assets/images/logo.jpeg ADDED
assets/images/teaser.jpg ADDED

Git LFS Details

  • SHA256: 6a295a02d73bf8619a1e6ef6d38eea5e1371c6a9c729bb4b9d4ae3e36a5517e8
  • Pointer size: 131 Bytes
  • Size of remote file: 669 kB
configs/inference/lam-20k-8gpu.yaml ADDED
@@ -0,0 +1,130 @@
1
+
2
+ experiment:
3
+ type: lam
4
+ seed: 42
5
+ parent: lam
6
+ child: lam_20k
7
+ model:
8
+ # image encoder
9
+ encoder_type: "dinov2_fusion"
10
+ encoder_model_name: "dinov2_vitl14_reg"
11
+ encoder_feat_dim: 1024
12
+ encoder_freeze: false
13
+
14
+ # points embeddings
15
+ latent_query_points_type: "e2e_flame"
16
+ pcl_dim: 1024
17
+
18
+ # transformer
19
+ transformer_type: "sd3_cond"
20
+ transformer_heads: 16
21
+ transformer_dim: 1024
22
+ transformer_layers: 10
23
+ tf_grad_ckpt: true
24
+ encoder_grad_ckpt: true
25
+
26
+ # for gs renderer
27
+ human_model_path: "./model_zoo/human_parametric_models"
28
+ flame_subdivide_num: 1
29
+ flame_type: "flame"
30
+ gs_query_dim: 1024
31
+ gs_use_rgb: True
32
+ gs_sh: 3
33
+ gs_mlp_network_config:
34
+ n_neurons: 512
35
+ n_hidden_layers: 2
36
+ activation: silu
37
+ gs_xyz_offset_max_step: 0.2
38
+ gs_clip_scaling: 0.01
39
+ scale_sphere: false
40
+
41
+ expr_param_dim: 10
42
+ shape_param_dim: 10
43
+ add_teeth: false
44
+
45
+ fix_opacity: false
46
+ fix_rotation: false
47
+
48
+ has_disc: false
49
+
50
+ teeth_bs_flag: false
51
+ oral_mesh_flag: false
52
+
53
+ dataset:
54
+ subsets:
55
+ - name: video_head
56
+ root_dirs: "./train_data/vfhq_vhap_nooffset/export"
57
+ meta_path:
58
+ train: "./train_data/vfhq_vhap_nooffset/label/valid_id_train_list.json"
59
+ val: "./train_data/vfhq_vhap_nooffset/label/valid_id_val_list.json"
60
+ sample_rate: 1.0
61
+ sample_side_views: 7
62
+ sample_aug_views: 0
63
+ source_image_res: 512
64
+ render_image:
65
+ low: 512
66
+ high: 512
67
+ region: null
68
+ num_train_workers: 4
69
+ num_val_workers: 2
70
+ pin_mem: true
71
+ repeat_num: 1
72
+ gaga_track_type: "vfhq"
73
+
74
+ train:
75
+ mixed_precision: bf16 # REPLACE THIS BASED ON GPU TYPE
76
+ find_unused_parameters: false
77
+ loss:
78
+ pixel_weight: 0.0
79
+ pixel_loss_fn: "mse"
80
+ crop_face_weight: 0.
81
+ crop_mouth_weight: 0.
82
+ crop_eye_weight: 0.
83
+ masked_pixel_weight: 1.0
84
+ perceptual_weight: 1.0
85
+ tv_weight: -1
86
+ mask_weight: 0:1.0:0.5:10000
87
+ offset_reg_weight: 0.1
88
+ optim:
89
+ lr: 4e-4
90
+ weight_decay: 0.05
91
+ beta1: 0.9
92
+ beta2: 0.95
93
+ clip_grad_norm: 1.0
94
+ scheduler:
95
+ type: cosine
96
+ warmup_real_iters: 3000
97
+ batch_size: 4 # REPLACE THIS (PER GPU)
98
+ accum_steps: 1 # REPLACE THIS
99
+ epochs: 100 # REPLACE THIS
100
+ debug_global_steps: null
101
+ resume: ""
102
+
103
+ val:
104
+ batch_size: 2
105
+ global_step_period: 500
106
+ debug_batches: 10
107
+
108
+ saver:
109
+ auto_resume: true
110
+ load_model: null
111
+ checkpoint_root: ./exps/checkpoints
112
+ checkpoint_global_steps: 500
113
+ checkpoint_keep_level: 5
114
+
115
+ logger:
116
+ stream_level: WARNING
117
+ log_level: INFO
118
+ log_root: ./exps/logs
119
+ tracker_root: ./exps/trackers
120
+ enable_profiler: false
121
+ trackers:
122
+ - tensorboard
123
+ image_monitor:
124
+ train_global_steps: 500
125
+ samples_per_log: 4
126
+
127
+ compile:
128
+ suppress_errors: true
129
+ print_specializations: true
130
+ disable: true
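
As a rough illustration of how a config like the one above can be consumed (for illustration only; the repository resolves configs through its own parse_configs helper, and OmegaConf is merely one plausible loader assumed here):

from omegaconf import OmegaConf  # assumption: OmegaConf as the YAML loader

cfg = OmegaConf.load("configs/inference/lam-20k-8gpu.yaml")

print(cfg.experiment.type)            # "lam"
print(cfg.model.encoder_model_name)   # "dinov2_vitl14_reg"
print(cfg.train.batch_size)           # 4 (per GPU, per the comment in the file)
print(OmegaConf.to_container(cfg.model.gs_mlp_network_config))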
configs/stylematte_config.json ADDED
@@ -0,0 +1,2311 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "activation_function": "relu",
4
+ "architectures": [
5
+ "Mask2FormerForUniversalSegmentation"
6
+ ],
7
+ "backbone_config": {
8
+ "_name_or_path": "",
9
+ "add_cross_attention": false,
10
+ "architectures": [
11
+ "SwinForImageClassification"
12
+ ],
13
+ "attention_probs_dropout_prob": 0.0,
14
+ "bad_words_ids": null,
15
+ "begin_suppress_tokens": null,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "depths": [
21
+ 2,
22
+ 2,
23
+ 6,
24
+ 2
25
+ ],
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "drop_path_rate": 0.3,
29
+ "early_stopping": false,
30
+ "embed_dim": 96,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "encoder_stride": 32,
33
+ "eos_token_id": null,
34
+ "exponential_decay_length_penalty": null,
35
+ "finetuning_task": null,
36
+ "forced_bos_token_id": null,
37
+ "forced_eos_token_id": null,
38
+ "hidden_act": "gelu",
39
+ "hidden_dropout_prob": 0.0,
40
+ "hidden_size": 768,
41
+ "id2label": {
42
+ "0": "tench, Tinca tinca",
43
+ "1": "goldfish, Carassius auratus",
44
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
45
+ "3": "tiger shark, Galeocerdo cuvieri",
46
+ "4": "hammerhead, hammerhead shark",
47
+ "5": "electric ray, crampfish, numbfish, torpedo",
48
+ "6": "stingray",
49
+ "7": "cock",
50
+ "8": "hen",
51
+ "9": "ostrich, Struthio camelus",
52
+ "10": "brambling, Fringilla montifringilla",
53
+ "11": "goldfinch, Carduelis carduelis",
54
+ "12": "house finch, linnet, Carpodacus mexicanus",
55
+ "13": "junco, snowbird",
56
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
57
+ "15": "robin, American robin, Turdus migratorius",
58
+ "16": "bulbul",
59
+ "17": "jay",
60
+ "18": "magpie",
61
+ "19": "chickadee",
62
+ "20": "water ouzel, dipper",
63
+ "21": "kite",
64
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
65
+ "23": "vulture",
66
+ "24": "great grey owl, great gray owl, Strix nebulosa",
67
+ "25": "European fire salamander, Salamandra salamandra",
68
+ "26": "common newt, Triturus vulgaris",
69
+ "27": "eft",
70
+ "28": "spotted salamander, Ambystoma maculatum",
71
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
72
+ "30": "bullfrog, Rana catesbeiana",
73
+ "31": "tree frog, tree-frog",
74
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
75
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
76
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
77
+ "35": "mud turtle",
78
+ "36": "terrapin",
79
+ "37": "box turtle, box tortoise",
80
+ "38": "banded gecko",
81
+ "39": "common iguana, iguana, Iguana iguana",
82
+ "40": "American chameleon, anole, Anolis carolinensis",
83
+ "41": "whiptail, whiptail lizard",
84
+ "42": "agama",
85
+ "43": "frilled lizard, Chlamydosaurus kingi",
86
+ "44": "alligator lizard",
87
+ "45": "Gila monster, Heloderma suspectum",
88
+ "46": "green lizard, Lacerta viridis",
89
+ "47": "African chameleon, Chamaeleo chamaeleon",
90
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
91
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
92
+ "50": "American alligator, Alligator mississipiensis",
93
+ "51": "triceratops",
94
+ "52": "thunder snake, worm snake, Carphophis amoenus",
95
+ "53": "ringneck snake, ring-necked snake, ring snake",
96
+ "54": "hognose snake, puff adder, sand viper",
97
+ "55": "green snake, grass snake",
98
+ "56": "king snake, kingsnake",
99
+ "57": "garter snake, grass snake",
100
+ "58": "water snake",
101
+ "59": "vine snake",
102
+ "60": "night snake, Hypsiglena torquata",
103
+ "61": "boa constrictor, Constrictor constrictor",
104
+ "62": "rock python, rock snake, Python sebae",
105
+ "63": "Indian cobra, Naja naja",
106
+ "64": "green mamba",
107
+ "65": "sea snake",
108
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
109
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
110
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
111
+ "69": "trilobite",
112
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
113
+ "71": "scorpion",
114
+ "72": "black and gold garden spider, Argiope aurantia",
115
+ "73": "barn spider, Araneus cavaticus",
116
+ "74": "garden spider, Aranea diademata",
117
+ "75": "black widow, Latrodectus mactans",
118
+ "76": "tarantula",
119
+ "77": "wolf spider, hunting spider",
120
+ "78": "tick",
121
+ "79": "centipede",
122
+ "80": "black grouse",
123
+ "81": "ptarmigan",
124
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
125
+ "83": "prairie chicken, prairie grouse, prairie fowl",
126
+ "84": "peacock",
127
+ "85": "quail",
128
+ "86": "partridge",
129
+ "87": "African grey, African gray, Psittacus erithacus",
130
+ "88": "macaw",
131
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
132
+ "90": "lorikeet",
133
+ "91": "coucal",
134
+ "92": "bee eater",
135
+ "93": "hornbill",
136
+ "94": "hummingbird",
137
+ "95": "jacamar",
138
+ "96": "toucan",
139
+ "97": "drake",
140
+ "98": "red-breasted merganser, Mergus serrator",
141
+ "99": "goose",
142
+ "100": "black swan, Cygnus atratus",
143
+ "101": "tusker",
144
+ "102": "echidna, spiny anteater, anteater",
145
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
146
+ "104": "wallaby, brush kangaroo",
147
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
148
+ "106": "wombat",
149
+ "107": "jellyfish",
150
+ "108": "sea anemone, anemone",
151
+ "109": "brain coral",
152
+ "110": "flatworm, platyhelminth",
153
+ "111": "nematode, nematode worm, roundworm",
154
+ "112": "conch",
155
+ "113": "snail",
156
+ "114": "slug",
157
+ "115": "sea slug, nudibranch",
158
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
159
+ "117": "chambered nautilus, pearly nautilus, nautilus",
160
+ "118": "Dungeness crab, Cancer magister",
161
+ "119": "rock crab, Cancer irroratus",
162
+ "120": "fiddler crab",
163
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
164
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
165
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
166
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
167
+ "125": "hermit crab",
168
+ "126": "isopod",
169
+ "127": "white stork, Ciconia ciconia",
170
+ "128": "black stork, Ciconia nigra",
171
+ "129": "spoonbill",
172
+ "130": "flamingo",
173
+ "131": "little blue heron, Egretta caerulea",
174
+ "132": "American egret, great white heron, Egretta albus",
175
+ "133": "bittern",
176
+ "134": "crane",
177
+ "135": "limpkin, Aramus pictus",
178
+ "136": "European gallinule, Porphyrio porphyrio",
179
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
180
+ "138": "bustard",
181
+ "139": "ruddy turnstone, Arenaria interpres",
182
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
183
+ "141": "redshank, Tringa totanus",
184
+ "142": "dowitcher",
185
+ "143": "oystercatcher, oyster catcher",
186
+ "144": "pelican",
187
+ "145": "king penguin, Aptenodytes patagonica",
188
+ "146": "albatross, mollymawk",
189
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
190
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
191
+ "149": "dugong, Dugong dugon",
192
+ "150": "sea lion",
193
+ "151": "Chihuahua",
194
+ "152": "Japanese spaniel",
195
+ "153": "Maltese dog, Maltese terrier, Maltese",
196
+ "154": "Pekinese, Pekingese, Peke",
197
+ "155": "Shih-Tzu",
198
+ "156": "Blenheim spaniel",
199
+ "157": "papillon",
200
+ "158": "toy terrier",
201
+ "159": "Rhodesian ridgeback",
202
+ "160": "Afghan hound, Afghan",
203
+ "161": "basset, basset hound",
204
+ "162": "beagle",
205
+ "163": "bloodhound, sleuthhound",
206
+ "164": "bluetick",
207
+ "165": "black-and-tan coonhound",
208
+ "166": "Walker hound, Walker foxhound",
209
+ "167": "English foxhound",
210
+ "168": "redbone",
211
+ "169": "borzoi, Russian wolfhound",
212
+ "170": "Irish wolfhound",
213
+ "171": "Italian greyhound",
214
+ "172": "whippet",
215
+ "173": "Ibizan hound, Ibizan Podenco",
216
+ "174": "Norwegian elkhound, elkhound",
217
+ "175": "otterhound, otter hound",
218
+ "176": "Saluki, gazelle hound",
219
+ "177": "Scottish deerhound, deerhound",
220
+ "178": "Weimaraner",
221
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
222
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
223
+ "181": "Bedlington terrier",
224
+ "182": "Border terrier",
225
+ "183": "Kerry blue terrier",
226
+ "184": "Irish terrier",
227
+ "185": "Norfolk terrier",
228
+ "186": "Norwich terrier",
229
+ "187": "Yorkshire terrier",
230
+ "188": "wire-haired fox terrier",
231
+ "189": "Lakeland terrier",
232
+ "190": "Sealyham terrier, Sealyham",
233
+ "191": "Airedale, Airedale terrier",
234
+ "192": "cairn, cairn terrier",
235
+ "193": "Australian terrier",
236
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
237
+ "195": "Boston bull, Boston terrier",
238
+ "196": "miniature schnauzer",
239
+ "197": "giant schnauzer",
240
+ "198": "standard schnauzer",
241
+ "199": "Scotch terrier, Scottish terrier, Scottie",
242
+ "200": "Tibetan terrier, chrysanthemum dog",
243
+ "201": "silky terrier, Sydney silky",
244
+ "202": "soft-coated wheaten terrier",
245
+ "203": "West Highland white terrier",
246
+ "204": "Lhasa, Lhasa apso",
247
+ "205": "flat-coated retriever",
248
+ "206": "curly-coated retriever",
249
+ "207": "golden retriever",
250
+ "208": "Labrador retriever",
251
+ "209": "Chesapeake Bay retriever",
252
+ "210": "German short-haired pointer",
253
+ "211": "vizsla, Hungarian pointer",
254
+ "212": "English setter",
255
+ "213": "Irish setter, red setter",
256
+ "214": "Gordon setter",
257
+ "215": "Brittany spaniel",
258
+ "216": "clumber, clumber spaniel",
259
+ "217": "English springer, English springer spaniel",
260
+ "218": "Welsh springer spaniel",
261
+ "219": "cocker spaniel, English cocker spaniel, cocker",
262
+ "220": "Sussex spaniel",
263
+ "221": "Irish water spaniel",
264
+ "222": "kuvasz",
265
+ "223": "schipperke",
266
+ "224": "groenendael",
267
+ "225": "malinois",
268
+ "226": "briard",
269
+ "227": "kelpie",
270
+ "228": "komondor",
271
+ "229": "Old English sheepdog, bobtail",
272
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
273
+ "231": "collie",
274
+ "232": "Border collie",
275
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
276
+ "234": "Rottweiler",
277
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
278
+ "236": "Doberman, Doberman pinscher",
279
+ "237": "miniature pinscher",
280
+ "238": "Greater Swiss Mountain dog",
281
+ "239": "Bernese mountain dog",
282
+ "240": "Appenzeller",
283
+ "241": "EntleBucher",
284
+ "242": "boxer",
285
+ "243": "bull mastiff",
286
+ "244": "Tibetan mastiff",
287
+ "245": "French bulldog",
288
+ "246": "Great Dane",
289
+ "247": "Saint Bernard, St Bernard",
290
+ "248": "Eskimo dog, husky",
291
+ "249": "malamute, malemute, Alaskan malamute",
292
+ "250": "Siberian husky",
293
+ "251": "dalmatian, coach dog, carriage dog",
294
+ "252": "affenpinscher, monkey pinscher, monkey dog",
295
+ "253": "basenji",
296
+ "254": "pug, pug-dog",
297
+ "255": "Leonberg",
298
+ "256": "Newfoundland, Newfoundland dog",
299
+ "257": "Great Pyrenees",
300
+ "258": "Samoyed, Samoyede",
301
+ "259": "Pomeranian",
302
+ "260": "chow, chow chow",
303
+ "261": "keeshond",
304
+ "262": "Brabancon griffon",
305
+ "263": "Pembroke, Pembroke Welsh corgi",
306
+ "264": "Cardigan, Cardigan Welsh corgi",
307
+ "265": "toy poodle",
308
+ "266": "miniature poodle",
309
+ "267": "standard poodle",
310
+ "268": "Mexican hairless",
311
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
312
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
313
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
314
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
315
+ "273": "dingo, warrigal, warragal, Canis dingo",
316
+ "274": "dhole, Cuon alpinus",
317
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
318
+ "276": "hyena, hyaena",
319
+ "277": "red fox, Vulpes vulpes",
320
+ "278": "kit fox, Vulpes macrotis",
321
+ "279": "Arctic fox, white fox, Alopex lagopus",
322
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
323
+ "281": "tabby, tabby cat",
324
+ "282": "tiger cat",
325
+ "283": "Persian cat",
326
+ "284": "Siamese cat, Siamese",
327
+ "285": "Egyptian cat",
328
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
329
+ "287": "lynx, catamount",
330
+ "288": "leopard, Panthera pardus",
331
+ "289": "snow leopard, ounce, Panthera uncia",
332
+ "290": "jaguar, panther, Panthera onca, Felis onca",
333
+ "291": "lion, king of beasts, Panthera leo",
334
+ "292": "tiger, Panthera tigris",
335
+ "293": "cheetah, chetah, Acinonyx jubatus",
336
+ "294": "brown bear, bruin, Ursus arctos",
337
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
338
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
339
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
340
+ "298": "mongoose",
341
+ "299": "meerkat, mierkat",
342
+ "300": "tiger beetle",
343
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
344
+ "302": "ground beetle, carabid beetle",
345
+ "303": "long-horned beetle, longicorn, longicorn beetle",
346
+ "304": "leaf beetle, chrysomelid",
347
+ "305": "dung beetle",
348
+ "306": "rhinoceros beetle",
349
+ "307": "weevil",
350
+ "308": "fly",
351
+ "309": "bee",
352
+ "310": "ant, emmet, pismire",
353
+ "311": "grasshopper, hopper",
354
+ "312": "cricket",
355
+ "313": "walking stick, walkingstick, stick insect",
356
+ "314": "cockroach, roach",
357
+ "315": "mantis, mantid",
358
+ "316": "cicada, cicala",
359
+ "317": "leafhopper",
360
+ "318": "lacewing, lacewing fly",
361
+ "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
362
+ "320": "damselfly",
363
+ "321": "admiral",
364
+ "322": "ringlet, ringlet butterfly",
365
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
366
+ "324": "cabbage butterfly",
367
+ "325": "sulphur butterfly, sulfur butterfly",
368
+ "326": "lycaenid, lycaenid butterfly",
369
+ "327": "starfish, sea star",
370
+ "328": "sea urchin",
371
+ "329": "sea cucumber, holothurian",
372
+ "330": "wood rabbit, cottontail, cottontail rabbit",
373
+ "331": "hare",
374
+ "332": "Angora, Angora rabbit",
375
+ "333": "hamster",
376
+ "334": "porcupine, hedgehog",
377
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
378
+ "336": "marmot",
379
+ "337": "beaver",
380
+ "338": "guinea pig, Cavia cobaya",
381
+ "339": "sorrel",
382
+ "340": "zebra",
383
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
384
+ "342": "wild boar, boar, Sus scrofa",
385
+ "343": "warthog",
386
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
387
+ "345": "ox",
388
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
389
+ "347": "bison",
390
+ "348": "ram, tup",
391
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
392
+ "350": "ibex, Capra ibex",
393
+ "351": "hartebeest",
394
+ "352": "impala, Aepyceros melampus",
395
+ "353": "gazelle",
396
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
397
+ "355": "llama",
398
+ "356": "weasel",
399
+ "357": "mink",
400
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
401
+ "359": "black-footed ferret, ferret, Mustela nigripes",
402
+ "360": "otter",
403
+ "361": "skunk, polecat, wood pussy",
404
+ "362": "badger",
405
+ "363": "armadillo",
406
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
407
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
408
+ "366": "gorilla, Gorilla gorilla",
409
+ "367": "chimpanzee, chimp, Pan troglodytes",
410
+ "368": "gibbon, Hylobates lar",
411
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
412
+ "370": "guenon, guenon monkey",
413
+ "371": "patas, hussar monkey, Erythrocebus patas",
414
+ "372": "baboon",
415
+ "373": "macaque",
416
+ "374": "langur",
417
+ "375": "colobus, colobus monkey",
418
+ "376": "proboscis monkey, Nasalis larvatus",
419
+ "377": "marmoset",
420
+ "378": "capuchin, ringtail, Cebus capucinus",
421
+ "379": "howler monkey, howler",
422
+ "380": "titi, titi monkey",
423
+ "381": "spider monkey, Ateles geoffroyi",
424
+ "382": "squirrel monkey, Saimiri sciureus",
425
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
426
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
427
+ "385": "Indian elephant, Elephas maximus",
428
+ "386": "African elephant, Loxodonta africana",
429
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
430
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
431
+ "389": "barracouta, snoek",
432
+ "390": "eel",
433
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
434
+ "392": "rock beauty, Holocanthus tricolor",
435
+ "393": "anemone fish",
436
+ "394": "sturgeon",
437
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
438
+ "396": "lionfish",
439
+ "397": "puffer, pufferfish, blowfish, globefish",
440
+ "398": "abacus",
441
+ "399": "abaya",
442
+ "400": "academic gown, academic robe, judge's robe",
443
+ "401": "accordion, piano accordion, squeeze box",
444
+ "402": "acoustic guitar",
445
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
446
+ "404": "airliner",
447
+ "405": "airship, dirigible",
448
+ "406": "altar",
449
+ "407": "ambulance",
450
+ "408": "amphibian, amphibious vehicle",
451
+ "409": "analog clock",
452
+ "410": "apiary, bee house",
453
+ "411": "apron",
454
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
455
+ "413": "assault rifle, assault gun",
456
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
457
+ "415": "bakery, bakeshop, bakehouse",
458
+ "416": "balance beam, beam",
459
+ "417": "balloon",
460
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
461
+ "419": "Band Aid",
462
+ "420": "banjo",
463
+ "421": "bannister, banister, balustrade, balusters, handrail",
464
+ "422": "barbell",
465
+ "423": "barber chair",
466
+ "424": "barbershop",
467
+ "425": "barn",
468
+ "426": "barometer",
469
+ "427": "barrel, cask",
470
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
471
+ "429": "baseball",
472
+ "430": "basketball",
473
+ "431": "bassinet",
474
+ "432": "bassoon",
475
+ "433": "bathing cap, swimming cap",
476
+ "434": "bath towel",
477
+ "435": "bathtub, bathing tub, bath, tub",
478
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
479
+ "437": "beacon, lighthouse, beacon light, pharos",
480
+ "438": "beaker",
481
+ "439": "bearskin, busby, shako",
482
+ "440": "beer bottle",
483
+ "441": "beer glass",
484
+ "442": "bell cote, bell cot",
485
+ "443": "bib",
486
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
487
+ "445": "bikini, two-piece",
488
+ "446": "binder, ring-binder",
489
+ "447": "binoculars, field glasses, opera glasses",
490
+ "448": "birdhouse",
491
+ "449": "boathouse",
492
+ "450": "bobsled, bobsleigh, bob",
493
+ "451": "bolo tie, bolo, bola tie, bola",
494
+ "452": "bonnet, poke bonnet",
495
+ "453": "bookcase",
496
+ "454": "bookshop, bookstore, bookstall",
497
+ "455": "bottlecap",
498
+ "456": "bow",
499
+ "457": "bow tie, bow-tie, bowtie",
500
+ "458": "brass, memorial tablet, plaque",
501
+ "459": "brassiere, bra, bandeau",
502
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
503
+ "461": "breastplate, aegis, egis",
504
+ "462": "broom",
505
+ "463": "bucket, pail",
506
+ "464": "buckle",
507
+ "465": "bulletproof vest",
508
+ "466": "bullet train, bullet",
509
+ "467": "butcher shop, meat market",
510
+ "468": "cab, hack, taxi, taxicab",
511
+ "469": "caldron, cauldron",
512
+ "470": "candle, taper, wax light",
513
+ "471": "cannon",
514
+ "472": "canoe",
515
+ "473": "can opener, tin opener",
516
+ "474": "cardigan",
517
+ "475": "car mirror",
518
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
519
+ "477": "carpenter's kit, tool kit",
520
+ "478": "carton",
521
+ "479": "car wheel",
522
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
523
+ "481": "cassette",
524
+ "482": "cassette player",
525
+ "483": "castle",
526
+ "484": "catamaran",
527
+ "485": "CD player",
528
+ "486": "cello, violoncello",
529
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
530
+ "488": "chain",
531
+ "489": "chainlink fence",
532
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
533
+ "491": "chain saw, chainsaw",
534
+ "492": "chest",
535
+ "493": "chiffonier, commode",
536
+ "494": "chime, bell, gong",
537
+ "495": "china cabinet, china closet",
538
+ "496": "Christmas stocking",
539
+ "497": "church, church building",
540
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
541
+ "499": "cleaver, meat cleaver, chopper",
542
+ "500": "cliff dwelling",
543
+ "501": "cloak",
544
+ "502": "clog, geta, patten, sabot",
545
+ "503": "cocktail shaker",
546
+ "504": "coffee mug",
547
+ "505": "coffeepot",
548
+ "506": "coil, spiral, volute, whorl, helix",
549
+ "507": "combination lock",
550
+ "508": "computer keyboard, keypad",
551
+ "509": "confectionery, confectionary, candy store",
552
+ "510": "container ship, containership, container vessel",
553
+ "511": "convertible",
554
+ "512": "corkscrew, bottle screw",
555
+ "513": "cornet, horn, trumpet, trump",
556
+ "514": "cowboy boot",
557
+ "515": "cowboy hat, ten-gallon hat",
558
+ "516": "cradle",
559
+ "517": "crane",
560
+ "518": "crash helmet",
561
+ "519": "crate",
562
+ "520": "crib, cot",
563
+ "521": "Crock Pot",
564
+ "522": "croquet ball",
565
+ "523": "crutch",
566
+ "524": "cuirass",
567
+ "525": "dam, dike, dyke",
568
+ "526": "desk",
569
+ "527": "desktop computer",
570
+ "528": "dial telephone, dial phone",
571
+ "529": "diaper, nappy, napkin",
572
+ "530": "digital clock",
573
+ "531": "digital watch",
574
+ "532": "dining table, board",
575
+ "533": "dishrag, dishcloth",
576
+ "534": "dishwasher, dish washer, dishwashing machine",
577
+ "535": "disk brake, disc brake",
578
+ "536": "dock, dockage, docking facility",
579
+ "537": "dogsled, dog sled, dog sleigh",
580
+ "538": "dome",
581
+ "539": "doormat, welcome mat",
582
+ "540": "drilling platform, offshore rig",
583
+ "541": "drum, membranophone, tympan",
584
+ "542": "drumstick",
585
+ "543": "dumbbell",
586
+ "544": "Dutch oven",
587
+ "545": "electric fan, blower",
588
+ "546": "electric guitar",
589
+ "547": "electric locomotive",
590
+ "548": "entertainment center",
591
+ "549": "envelope",
592
+ "550": "espresso maker",
593
+ "551": "face powder",
594
+ "552": "feather boa, boa",
595
+ "553": "file, file cabinet, filing cabinet",
596
+ "554": "fireboat",
597
+ "555": "fire engine, fire truck",
598
+ "556": "fire screen, fireguard",
599
+ "557": "flagpole, flagstaff",
600
+ "558": "flute, transverse flute",
601
+ "559": "folding chair",
602
+ "560": "football helmet",
603
+ "561": "forklift",
604
+ "562": "fountain",
605
+ "563": "fountain pen",
606
+ "564": "four-poster",
607
+ "565": "freight car",
608
+ "566": "French horn, horn",
609
+ "567": "frying pan, frypan, skillet",
610
+ "568": "fur coat",
611
+ "569": "garbage truck, dustcart",
612
+ "570": "gasmask, respirator, gas helmet",
613
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
614
+ "572": "goblet",
615
+ "573": "go-kart",
616
+ "574": "golf ball",
617
+ "575": "golfcart, golf cart",
618
+ "576": "gondola",
619
+ "577": "gong, tam-tam",
620
+ "578": "gown",
621
+ "579": "grand piano, grand",
622
+ "580": "greenhouse, nursery, glasshouse",
623
+ "581": "grille, radiator grille",
624
+ "582": "grocery store, grocery, food market, market",
625
+ "583": "guillotine",
626
+ "584": "hair slide",
627
+ "585": "hair spray",
628
+ "586": "half track",
629
+ "587": "hammer",
630
+ "588": "hamper",
631
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
632
+ "590": "hand-held computer, hand-held microcomputer",
633
+ "591": "handkerchief, hankie, hanky, hankey",
634
+ "592": "hard disc, hard disk, fixed disk",
635
+ "593": "harmonica, mouth organ, harp, mouth harp",
636
+ "594": "harp",
637
+ "595": "harvester, reaper",
638
+ "596": "hatchet",
639
+ "597": "holster",
640
+ "598": "home theater, home theatre",
641
+ "599": "honeycomb",
642
+ "600": "hook, claw",
643
+ "601": "hoopskirt, crinoline",
644
+ "602": "horizontal bar, high bar",
645
+ "603": "horse cart, horse-cart",
646
+ "604": "hourglass",
647
+ "605": "iPod",
648
+ "606": "iron, smoothing iron",
649
+ "607": "jack-o'-lantern",
650
+ "608": "jean, blue jean, denim",
651
+ "609": "jeep, landrover",
652
+ "610": "jersey, T-shirt, tee shirt",
653
+ "611": "jigsaw puzzle",
654
+ "612": "jinrikisha, ricksha, rickshaw",
655
+ "613": "joystick",
656
+ "614": "kimono",
657
+ "615": "knee pad",
658
+ "616": "knot",
659
+ "617": "lab coat, laboratory coat",
660
+ "618": "ladle",
661
+ "619": "lampshade, lamp shade",
662
+ "620": "laptop, laptop computer",
663
+ "621": "lawn mower, mower",
664
+ "622": "lens cap, lens cover",
665
+ "623": "letter opener, paper knife, paperknife",
666
+ "624": "library",
667
+ "625": "lifeboat",
668
+ "626": "lighter, light, igniter, ignitor",
669
+ "627": "limousine, limo",
670
+ "628": "liner, ocean liner",
671
+ "629": "lipstick, lip rouge",
672
+ "630": "Loafer",
673
+ "631": "lotion",
674
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
675
+ "633": "loupe, jeweler's loupe",
676
+ "634": "lumbermill, sawmill",
677
+ "635": "magnetic compass",
678
+ "636": "mailbag, postbag",
679
+ "637": "mailbox, letter box",
680
+ "638": "maillot",
681
+ "639": "maillot, tank suit",
682
+ "640": "manhole cover",
683
+ "641": "maraca",
684
+ "642": "marimba, xylophone",
685
+ "643": "mask",
686
+ "644": "matchstick",
687
+ "645": "maypole",
688
+ "646": "maze, labyrinth",
689
+ "647": "measuring cup",
690
+ "648": "medicine chest, medicine cabinet",
691
+ "649": "megalith, megalithic structure",
692
+ "650": "microphone, mike",
693
+ "651": "microwave, microwave oven",
694
+ "652": "military uniform",
695
+ "653": "milk can",
696
+ "654": "minibus",
697
+ "655": "miniskirt, mini",
698
+ "656": "minivan",
699
+ "657": "missile",
700
+ "658": "mitten",
701
+ "659": "mixing bowl",
702
+ "660": "mobile home, manufactured home",
703
+ "661": "Model T",
704
+ "662": "modem",
705
+ "663": "monastery",
706
+ "664": "monitor",
707
+ "665": "moped",
708
+ "666": "mortar",
709
+ "667": "mortarboard",
710
+ "668": "mosque",
711
+ "669": "mosquito net",
712
+ "670": "motor scooter, scooter",
713
+ "671": "mountain bike, all-terrain bike, off-roader",
714
+ "672": "mountain tent",
715
+ "673": "mouse, computer mouse",
716
+ "674": "mousetrap",
717
+ "675": "moving van",
718
+ "676": "muzzle",
719
+ "677": "nail",
720
+ "678": "neck brace",
721
+ "679": "necklace",
722
+ "680": "nipple",
723
+ "681": "notebook, notebook computer",
724
+ "682": "obelisk",
725
+ "683": "oboe, hautboy, hautbois",
726
+ "684": "ocarina, sweet potato",
727
+ "685": "odometer, hodometer, mileometer, milometer",
728
+ "686": "oil filter",
729
+ "687": "organ, pipe organ",
730
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
731
+ "689": "overskirt",
732
+ "690": "oxcart",
733
+ "691": "oxygen mask",
734
+ "692": "packet",
735
+ "693": "paddle, boat paddle",
736
+ "694": "paddlewheel, paddle wheel",
737
+ "695": "padlock",
738
+ "696": "paintbrush",
739
+ "697": "pajama, pyjama, pj's, jammies",
740
+ "698": "palace",
741
+ "699": "panpipe, pandean pipe, syrinx",
742
+ "700": "paper towel",
743
+ "701": "parachute, chute",
744
+ "702": "parallel bars, bars",
745
+ "703": "park bench",
746
+ "704": "parking meter",
747
+ "705": "passenger car, coach, carriage",
748
+ "706": "patio, terrace",
749
+ "707": "pay-phone, pay-station",
750
+ "708": "pedestal, plinth, footstall",
751
+ "709": "pencil box, pencil case",
752
+ "710": "pencil sharpener",
753
+ "711": "perfume, essence",
754
+ "712": "Petri dish",
755
+ "713": "photocopier",
756
+ "714": "pick, plectrum, plectron",
757
+ "715": "pickelhaube",
758
+ "716": "picket fence, paling",
759
+ "717": "pickup, pickup truck",
760
+ "718": "pier",
761
+ "719": "piggy bank, penny bank",
762
+ "720": "pill bottle",
763
+ "721": "pillow",
764
+ "722": "ping-pong ball",
765
+ "723": "pinwheel",
766
+ "724": "pirate, pirate ship",
767
+ "725": "pitcher, ewer",
768
+ "726": "plane, carpenter's plane, woodworking plane",
769
+ "727": "planetarium",
770
+ "728": "plastic bag",
771
+ "729": "plate rack",
772
+ "730": "plow, plough",
773
+ "731": "plunger, plumber's helper",
774
+ "732": "Polaroid camera, Polaroid Land camera",
775
+ "733": "pole",
776
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
777
+ "735": "poncho",
778
+ "736": "pool table, billiard table, snooker table",
779
+ "737": "pop bottle, soda bottle",
780
+ "738": "pot, flowerpot",
781
+ "739": "potter's wheel",
782
+ "740": "power drill",
783
+ "741": "prayer rug, prayer mat",
784
+ "742": "printer",
785
+ "743": "prison, prison house",
786
+ "744": "projectile, missile",
787
+ "745": "projector",
788
+ "746": "puck, hockey puck",
789
+ "747": "punching bag, punch bag, punching ball, punchball",
790
+ "748": "purse",
791
+ "749": "quill, quill pen",
792
+ "750": "quilt, comforter, comfort, puff",
793
+ "751": "racer, race car, racing car",
794
+ "752": "racket, racquet",
795
+ "753": "radiator",
796
+ "754": "radio, wireless",
797
+ "755": "radio telescope, radio reflector",
798
+ "756": "rain barrel",
799
+ "757": "recreational vehicle, RV, R.V.",
800
+ "758": "reel",
801
+ "759": "reflex camera",
802
+ "760": "refrigerator, icebox",
803
+ "761": "remote control, remote",
804
+ "762": "restaurant, eating house, eating place, eatery",
805
+ "763": "revolver, six-gun, six-shooter",
806
+ "764": "rifle",
807
+ "765": "rocking chair, rocker",
808
+ "766": "rotisserie",
809
+ "767": "rubber eraser, rubber, pencil eraser",
810
+ "768": "rugby ball",
811
+ "769": "rule, ruler",
812
+ "770": "running shoe",
813
+ "771": "safe",
814
+ "772": "safety pin",
815
+ "773": "saltshaker, salt shaker",
816
+ "774": "sandal",
817
+ "775": "sarong",
818
+ "776": "sax, saxophone",
819
+ "777": "scabbard",
820
+ "778": "scale, weighing machine",
821
+ "779": "school bus",
822
+ "780": "schooner",
823
+ "781": "scoreboard",
824
+ "782": "screen, CRT screen",
825
+ "783": "screw",
826
+ "784": "screwdriver",
827
+ "785": "seat belt, seatbelt",
828
+ "786": "sewing machine",
829
+ "787": "shield, buckler",
830
+ "788": "shoe shop, shoe-shop, shoe store",
831
+ "789": "shoji",
832
+ "790": "shopping basket",
833
+ "791": "shopping cart",
834
+ "792": "shovel",
835
+ "793": "shower cap",
836
+ "794": "shower curtain",
837
+ "795": "ski",
838
+ "796": "ski mask",
839
+ "797": "sleeping bag",
840
+ "798": "slide rule, slipstick",
841
+ "799": "sliding door",
842
+ "800": "slot, one-armed bandit",
843
+ "801": "snorkel",
844
+ "802": "snowmobile",
845
+ "803": "snowplow, snowplough",
846
+ "804": "soap dispenser",
847
+ "805": "soccer ball",
848
+ "806": "sock",
849
+ "807": "solar dish, solar collector, solar furnace",
850
+ "808": "sombrero",
851
+ "809": "soup bowl",
852
+ "810": "space bar",
853
+ "811": "space heater",
854
+ "812": "space shuttle",
855
+ "813": "spatula",
856
+ "814": "speedboat",
857
+ "815": "spider web, spider's web",
858
+ "816": "spindle",
859
+ "817": "sports car, sport car",
860
+ "818": "spotlight, spot",
861
+ "819": "stage",
862
+ "820": "steam locomotive",
863
+ "821": "steel arch bridge",
864
+ "822": "steel drum",
865
+ "823": "stethoscope",
866
+ "824": "stole",
867
+ "825": "stone wall",
868
+ "826": "stopwatch, stop watch",
869
+ "827": "stove",
870
+ "828": "strainer",
871
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
872
+ "830": "stretcher",
873
+ "831": "studio couch, day bed",
874
+ "832": "stupa, tope",
875
+ "833": "submarine, pigboat, sub, U-boat",
876
+ "834": "suit, suit of clothes",
877
+ "835": "sundial",
878
+ "836": "sunglass",
879
+ "837": "sunglasses, dark glasses, shades",
880
+ "838": "sunscreen, sunblock, sun blocker",
881
+ "839": "suspension bridge",
882
+ "840": "swab, swob, mop",
883
+ "841": "sweatshirt",
884
+ "842": "swimming trunks, bathing trunks",
885
+ "843": "swing",
886
+ "844": "switch, electric switch, electrical switch",
887
+ "845": "syringe",
888
+ "846": "table lamp",
889
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
890
+ "848": "tape player",
891
+ "849": "teapot",
892
+ "850": "teddy, teddy bear",
893
+ "851": "television, television system",
894
+ "852": "tennis ball",
895
+ "853": "thatch, thatched roof",
896
+ "854": "theater curtain, theatre curtain",
897
+ "855": "thimble",
898
+ "856": "thresher, thrasher, threshing machine",
899
+ "857": "throne",
900
+ "858": "tile roof",
901
+ "859": "toaster",
902
+ "860": "tobacco shop, tobacconist shop, tobacconist",
903
+ "861": "toilet seat",
904
+ "862": "torch",
905
+ "863": "totem pole",
906
+ "864": "tow truck, tow car, wrecker",
907
+ "865": "toyshop",
908
+ "866": "tractor",
909
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
910
+ "868": "tray",
911
+ "869": "trench coat",
912
+ "870": "tricycle, trike, velocipede",
913
+ "871": "trimaran",
914
+ "872": "tripod",
915
+ "873": "triumphal arch",
916
+ "874": "trolleybus, trolley coach, trackless trolley",
917
+ "875": "trombone",
918
+ "876": "tub, vat",
919
+ "877": "turnstile",
920
+ "878": "typewriter keyboard",
921
+ "879": "umbrella",
922
+ "880": "unicycle, monocycle",
923
+ "881": "upright, upright piano",
924
+ "882": "vacuum, vacuum cleaner",
925
+ "883": "vase",
926
+ "884": "vault",
927
+ "885": "velvet",
928
+ "886": "vending machine",
929
+ "887": "vestment",
930
+ "888": "viaduct",
931
+ "889": "violin, fiddle",
932
+ "890": "volleyball",
933
+ "891": "waffle iron",
934
+ "892": "wall clock",
935
+ "893": "wallet, billfold, notecase, pocketbook",
936
+ "894": "wardrobe, closet, press",
937
+ "895": "warplane, military plane",
938
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
939
+ "897": "washer, automatic washer, washing machine",
940
+ "898": "water bottle",
941
+ "899": "water jug",
942
+ "900": "water tower",
943
+ "901": "whiskey jug",
944
+ "902": "whistle",
945
+ "903": "wig",
946
+ "904": "window screen",
947
+ "905": "window shade",
948
+ "906": "Windsor tie",
949
+ "907": "wine bottle",
950
+ "908": "wing",
951
+ "909": "wok",
952
+ "910": "wooden spoon",
953
+ "911": "wool, woolen, woollen",
954
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
955
+ "913": "wreck",
956
+ "914": "yawl",
957
+ "915": "yurt",
958
+ "916": "web site, website, internet site, site",
959
+ "917": "comic book",
960
+ "918": "crossword puzzle, crossword",
961
+ "919": "street sign",
962
+ "920": "traffic light, traffic signal, stoplight",
963
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
964
+ "922": "menu",
965
+ "923": "plate",
966
+ "924": "guacamole",
967
+ "925": "consomme",
968
+ "926": "hot pot, hotpot",
969
+ "927": "trifle",
970
+ "928": "ice cream, icecream",
971
+ "929": "ice lolly, lolly, lollipop, popsicle",
972
+ "930": "French loaf",
973
+ "931": "bagel, beigel",
974
+ "932": "pretzel",
975
+ "933": "cheeseburger",
976
+ "934": "hotdog, hot dog, red hot",
977
+ "935": "mashed potato",
978
+ "936": "head cabbage",
979
+ "937": "broccoli",
980
+ "938": "cauliflower",
981
+ "939": "zucchini, courgette",
982
+ "940": "spaghetti squash",
983
+ "941": "acorn squash",
984
+ "942": "butternut squash",
985
+ "943": "cucumber, cuke",
986
+ "944": "artichoke, globe artichoke",
987
+ "945": "bell pepper",
988
+ "946": "cardoon",
989
+ "947": "mushroom",
990
+ "948": "Granny Smith",
991
+ "949": "strawberry",
992
+ "950": "orange",
993
+ "951": "lemon",
994
+ "952": "fig",
995
+ "953": "pineapple, ananas",
996
+ "954": "banana",
997
+ "955": "jackfruit, jak, jack",
998
+ "956": "custard apple",
999
+ "957": "pomegranate",
1000
+ "958": "hay",
1001
+ "959": "carbonara",
1002
+ "960": "chocolate sauce, chocolate syrup",
1003
+ "961": "dough",
1004
+ "962": "meat loaf, meatloaf",
1005
+ "963": "pizza, pizza pie",
1006
+ "964": "potpie",
1007
+ "965": "burrito",
1008
+ "966": "red wine",
1009
+ "967": "espresso",
1010
+ "968": "cup",
1011
+ "969": "eggnog",
1012
+ "970": "alp",
1013
+ "971": "bubble",
1014
+ "972": "cliff, drop, drop-off",
1015
+ "973": "coral reef",
1016
+ "974": "geyser",
1017
+ "975": "lakeside, lakeshore",
1018
+ "976": "promontory, headland, head, foreland",
1019
+ "977": "sandbar, sand bar",
1020
+ "978": "seashore, coast, seacoast, sea-coast",
1021
+ "979": "valley, vale",
1022
+ "980": "volcano",
1023
+ "981": "ballplayer, baseball player",
1024
+ "982": "groom, bridegroom",
1025
+ "983": "scuba diver",
1026
+ "984": "rapeseed",
1027
+ "985": "daisy",
1028
+ "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1029
+ "987": "corn",
1030
+ "988": "acorn",
1031
+ "989": "hip, rose hip, rosehip",
1032
+ "990": "buckeye, horse chestnut, conker",
1033
+ "991": "coral fungus",
1034
+ "992": "agaric",
1035
+ "993": "gyromitra",
1036
+ "994": "stinkhorn, carrion fungus",
1037
+ "995": "earthstar",
1038
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1039
+ "997": "bolete",
1040
+ "998": "ear, spike, capitulum",
1041
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1042
+ },
1043
+ "image_size": 224,
1044
+ "initializer_range": 0.02,
1045
+ "is_decoder": false,
1046
+ "is_encoder_decoder": false,
1047
+ "label2id": {
1048
+ "Afghan hound, Afghan": 160,
1049
+ "African chameleon, Chamaeleo chamaeleon": 47,
1050
+ "African crocodile, Nile crocodile, Crocodylus niloticus": 49,
1051
+ "African elephant, Loxodonta africana": 386,
1052
+ "African grey, African gray, Psittacus erithacus": 87,
1053
+ "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275,
1054
+ "Airedale, Airedale terrier": 191,
1055
+ "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180,
1056
+ "American alligator, Alligator mississipiensis": 50,
1057
+ "American black bear, black bear, Ursus americanus, Euarctos americanus": 295,
1058
+ "American chameleon, anole, Anolis carolinensis": 40,
1059
+ "American coot, marsh hen, mud hen, water hen, Fulica americana": 137,
1060
+ "American egret, great white heron, Egretta albus": 132,
1061
+ "American lobster, Northern lobster, Maine lobster, Homarus americanus": 122,
1062
+ "Angora, Angora rabbit": 332,
1063
+ "Appenzeller": 240,
1064
+ "Arabian camel, dromedary, Camelus dromedarius": 354,
1065
+ "Arctic fox, white fox, Alopex lagopus": 279,
1066
+ "Australian terrier": 193,
1067
+ "Band Aid": 419,
1068
+ "Bedlington terrier": 181,
1069
+ "Bernese mountain dog": 239,
1070
+ "Blenheim spaniel": 156,
1071
+ "Border collie": 232,
1072
+ "Border terrier": 182,
1073
+ "Boston bull, Boston terrier": 195,
1074
+ "Bouvier des Flandres, Bouviers des Flandres": 233,
1075
+ "Brabancon griffon": 262,
1076
+ "Brittany spaniel": 215,
1077
+ "CD player": 485,
1078
+ "Cardigan, Cardigan Welsh corgi": 264,
1079
+ "Chesapeake Bay retriever": 209,
1080
+ "Chihuahua": 151,
1081
+ "Christmas stocking": 496,
1082
+ "Crock Pot": 521,
1083
+ "Dandie Dinmont, Dandie Dinmont terrier": 194,
1084
+ "Doberman, Doberman pinscher": 236,
1085
+ "Dungeness crab, Cancer magister": 118,
1086
+ "Dutch oven": 544,
1087
+ "Egyptian cat": 285,
1088
+ "English foxhound": 167,
1089
+ "English setter": 212,
1090
+ "English springer, English springer spaniel": 217,
1091
+ "EntleBucher": 241,
1092
+ "Eskimo dog, husky": 248,
1093
+ "European fire salamander, Salamandra salamandra": 25,
1094
+ "European gallinule, Porphyrio porphyrio": 136,
1095
+ "French bulldog": 245,
1096
+ "French horn, horn": 566,
1097
+ "French loaf": 930,
1098
+ "German shepherd, German shepherd dog, German police dog, alsatian": 235,
1099
+ "German short-haired pointer": 210,
1100
+ "Gila monster, Heloderma suspectum": 45,
1101
+ "Gordon setter": 214,
1102
+ "Granny Smith": 948,
1103
+ "Great Dane": 246,
1104
+ "Great Pyrenees": 257,
1105
+ "Greater Swiss Mountain dog": 238,
1106
+ "Ibizan hound, Ibizan Podenco": 173,
1107
+ "Indian cobra, Naja naja": 63,
1108
+ "Indian elephant, Elephas maximus": 385,
1109
+ "Irish setter, red setter": 213,
1110
+ "Irish terrier": 184,
1111
+ "Irish water spaniel": 221,
1112
+ "Irish wolfhound": 170,
1113
+ "Italian greyhound": 171,
1114
+ "Japanese spaniel": 152,
1115
+ "Kerry blue terrier": 183,
1116
+ "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis": 48,
1117
+ "Labrador retriever": 208,
1118
+ "Lakeland terrier": 189,
1119
+ "Leonberg": 255,
1120
+ "Lhasa, Lhasa apso": 204,
1121
+ "Loafer": 630,
1122
+ "Madagascar cat, ring-tailed lemur, Lemur catta": 383,
1123
+ "Maltese dog, Maltese terrier, Maltese": 153,
1124
+ "Mexican hairless": 268,
1125
+ "Model T": 661,
1126
+ "Newfoundland, Newfoundland dog": 256,
1127
+ "Norfolk terrier": 185,
1128
+ "Norwegian elkhound, elkhound": 174,
1129
+ "Norwich terrier": 186,
1130
+ "Old English sheepdog, bobtail": 229,
1131
+ "Pekinese, Pekingese, Peke": 154,
1132
+ "Pembroke, Pembroke Welsh corgi": 263,
1133
+ "Persian cat": 283,
1134
+ "Petri dish": 712,
1135
+ "Polaroid camera, Polaroid Land camera": 732,
1136
+ "Pomeranian": 259,
1137
+ "Rhodesian ridgeback": 159,
1138
+ "Rottweiler": 234,
1139
+ "Saint Bernard, St Bernard": 247,
1140
+ "Saluki, gazelle hound": 176,
1141
+ "Samoyed, Samoyede": 258,
1142
+ "Scotch terrier, Scottish terrier, Scottie": 199,
1143
+ "Scottish deerhound, deerhound": 177,
1144
+ "Sealyham terrier, Sealyham": 190,
1145
+ "Shetland sheepdog, Shetland sheep dog, Shetland": 230,
1146
+ "Shih-Tzu": 155,
1147
+ "Siamese cat, Siamese": 284,
1148
+ "Siberian husky": 250,
1149
+ "Staffordshire bullterrier, Staffordshire bull terrier": 179,
1150
+ "Sussex spaniel": 220,
1151
+ "Tibetan mastiff": 244,
1152
+ "Tibetan terrier, chrysanthemum dog": 200,
1153
+ "Walker hound, Walker foxhound": 166,
1154
+ "Weimaraner": 178,
1155
+ "Welsh springer spaniel": 218,
1156
+ "West Highland white terrier": 203,
1157
+ "Windsor tie": 906,
1158
+ "Yorkshire terrier": 187,
1159
+ "abacus": 398,
1160
+ "abaya": 399,
1161
+ "academic gown, academic robe, judge's robe": 400,
1162
+ "accordion, piano accordion, squeeze box": 401,
1163
+ "acorn": 988,
1164
+ "acorn squash": 941,
1165
+ "acoustic guitar": 402,
1166
+ "admiral": 321,
1167
+ "affenpinscher, monkey pinscher, monkey dog": 252,
1168
+ "agama": 42,
1169
+ "agaric": 992,
1170
+ "aircraft carrier, carrier, flattop, attack aircraft carrier": 403,
1171
+ "airliner": 404,
1172
+ "airship, dirigible": 405,
1173
+ "albatross, mollymawk": 146,
1174
+ "alligator lizard": 44,
1175
+ "alp": 970,
1176
+ "altar": 406,
1177
+ "ambulance": 407,
1178
+ "amphibian, amphibious vehicle": 408,
1179
+ "analog clock": 409,
1180
+ "anemone fish": 393,
1181
+ "ant, emmet, pismire": 310,
1182
+ "apiary, bee house": 410,
1183
+ "apron": 411,
1184
+ "armadillo": 363,
1185
+ "artichoke, globe artichoke": 944,
1186
+ "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412,
1187
+ "assault rifle, assault gun": 413,
1188
+ "axolotl, mud puppy, Ambystoma mexicanum": 29,
1189
+ "baboon": 372,
1190
+ "backpack, back pack, knapsack, packsack, rucksack, haversack": 414,
1191
+ "badger": 362,
1192
+ "bagel, beigel": 931,
1193
+ "bakery, bakeshop, bakehouse": 415,
1194
+ "balance beam, beam": 416,
1195
+ "bald eagle, American eagle, Haliaeetus leucocephalus": 22,
1196
+ "balloon": 417,
1197
+ "ballplayer, baseball player": 981,
1198
+ "ballpoint, ballpoint pen, ballpen, Biro": 418,
1199
+ "banana": 954,
1200
+ "banded gecko": 38,
1201
+ "banjo": 420,
1202
+ "bannister, banister, balustrade, balusters, handrail": 421,
1203
+ "barbell": 422,
1204
+ "barber chair": 423,
1205
+ "barbershop": 424,
1206
+ "barn": 425,
1207
+ "barn spider, Araneus cavaticus": 73,
1208
+ "barometer": 426,
1209
+ "barracouta, snoek": 389,
1210
+ "barrel, cask": 427,
1211
+ "barrow, garden cart, lawn cart, wheelbarrow": 428,
1212
+ "baseball": 429,
1213
+ "basenji": 253,
1214
+ "basketball": 430,
1215
+ "basset, basset hound": 161,
1216
+ "bassinet": 431,
1217
+ "bassoon": 432,
1218
+ "bath towel": 434,
1219
+ "bathing cap, swimming cap": 433,
1220
+ "bathtub, bathing tub, bath, tub": 435,
1221
+ "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436,
1222
+ "beacon, lighthouse, beacon light, pharos": 437,
1223
+ "beagle": 162,
1224
+ "beaker": 438,
1225
+ "bearskin, busby, shako": 439,
1226
+ "beaver": 337,
1227
+ "bee": 309,
1228
+ "bee eater": 92,
1229
+ "beer bottle": 440,
1230
+ "beer glass": 441,
1231
+ "bell cote, bell cot": 442,
1232
+ "bell pepper": 945,
1233
+ "bib": 443,
1234
+ "bicycle-built-for-two, tandem bicycle, tandem": 444,
1235
+ "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349,
1236
+ "bikini, two-piece": 445,
1237
+ "binder, ring-binder": 446,
1238
+ "binoculars, field glasses, opera glasses": 447,
1239
+ "birdhouse": 448,
1240
+ "bison": 347,
1241
+ "bittern": 133,
1242
+ "black and gold garden spider, Argiope aurantia": 72,
1243
+ "black grouse": 80,
1244
+ "black stork, Ciconia nigra": 128,
1245
+ "black swan, Cygnus atratus": 100,
1246
+ "black widow, Latrodectus mactans": 75,
1247
+ "black-and-tan coonhound": 165,
1248
+ "black-footed ferret, ferret, Mustela nigripes": 359,
1249
+ "bloodhound, sleuthhound": 163,
1250
+ "bluetick": 164,
1251
+ "boa constrictor, Constrictor constrictor": 61,
1252
+ "boathouse": 449,
1253
+ "bobsled, bobsleigh, bob": 450,
1254
+ "bolete": 997,
1255
+ "bolo tie, bolo, bola tie, bola": 451,
1256
+ "bonnet, poke bonnet": 452,
1257
+ "book jacket, dust cover, dust jacket, dust wrapper": 921,
1258
+ "bookcase": 453,
1259
+ "bookshop, bookstore, bookstall": 454,
1260
+ "borzoi, Russian wolfhound": 169,
1261
+ "bottlecap": 455,
1262
+ "bow": 456,
1263
+ "bow tie, bow-tie, bowtie": 457,
1264
+ "box turtle, box tortoise": 37,
1265
+ "boxer": 242,
1266
+ "brain coral": 109,
1267
+ "brambling, Fringilla montifringilla": 10,
1268
+ "brass, memorial tablet, plaque": 458,
1269
+ "brassiere, bra, bandeau": 459,
1270
+ "breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460,
1271
+ "breastplate, aegis, egis": 461,
1272
+ "briard": 226,
1273
+ "broccoli": 937,
1274
+ "broom": 462,
1275
+ "brown bear, bruin, Ursus arctos": 294,
1276
+ "bubble": 971,
1277
+ "bucket, pail": 463,
1278
+ "buckeye, horse chestnut, conker": 990,
1279
+ "buckle": 464,
1280
+ "bulbul": 16,
1281
+ "bull mastiff": 243,
1282
+ "bullet train, bullet": 466,
1283
+ "bulletproof vest": 465,
1284
+ "bullfrog, Rana catesbeiana": 30,
1285
+ "burrito": 965,
1286
+ "bustard": 138,
1287
+ "butcher shop, meat market": 467,
1288
+ "butternut squash": 942,
1289
+ "cab, hack, taxi, taxicab": 468,
1290
+ "cabbage butterfly": 324,
1291
+ "cairn, cairn terrier": 192,
1292
+ "caldron, cauldron": 469,
1293
+ "can opener, tin opener": 473,
1294
+ "candle, taper, wax light": 470,
1295
+ "cannon": 471,
1296
+ "canoe": 472,
1297
+ "capuchin, ringtail, Cebus capucinus": 378,
1298
+ "car mirror": 475,
1299
+ "car wheel": 479,
1300
+ "carbonara": 959,
1301
+ "cardigan": 474,
1302
+ "cardoon": 946,
1303
+ "carousel, carrousel, merry-go-round, roundabout, whirligig": 476,
1304
+ "carpenter's kit, tool kit": 477,
1305
+ "carton": 478,
1306
+ "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480,
1307
+ "cassette": 481,
1308
+ "cassette player": 482,
1309
+ "castle": 483,
1310
+ "catamaran": 484,
1311
+ "cauliflower": 938,
1312
+ "cello, violoncello": 486,
1313
+ "cellular telephone, cellular phone, cellphone, cell, mobile phone": 487,
1314
+ "centipede": 79,
1315
+ "chain": 488,
1316
+ "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490,
1317
+ "chain saw, chainsaw": 491,
1318
+ "chainlink fence": 489,
1319
+ "chambered nautilus, pearly nautilus, nautilus": 117,
1320
+ "cheeseburger": 933,
1321
+ "cheetah, chetah, Acinonyx jubatus": 293,
1322
+ "chest": 492,
1323
+ "chickadee": 19,
1324
+ "chiffonier, commode": 493,
1325
+ "chime, bell, gong": 494,
1326
+ "chimpanzee, chimp, Pan troglodytes": 367,
1327
+ "china cabinet, china closet": 495,
1328
+ "chiton, coat-of-mail shell, sea cradle, polyplacophore": 116,
1329
+ "chocolate sauce, chocolate syrup": 960,
1330
+ "chow, chow chow": 260,
1331
+ "church, church building": 497,
1332
+ "cicada, cicala": 316,
1333
+ "cinema, movie theater, movie theatre, movie house, picture palace": 498,
1334
+ "cleaver, meat cleaver, chopper": 499,
1335
+ "cliff dwelling": 500,
1336
+ "cliff, drop, drop-off": 972,
1337
+ "cloak": 501,
1338
+ "clog, geta, patten, sabot": 502,
1339
+ "clumber, clumber spaniel": 216,
1340
+ "cock": 7,
1341
+ "cocker spaniel, English cocker spaniel, cocker": 219,
1342
+ "cockroach, roach": 314,
1343
+ "cocktail shaker": 503,
1344
+ "coffee mug": 504,
1345
+ "coffeepot": 505,
1346
+ "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391,
1347
+ "coil, spiral, volute, whorl, helix": 506,
1348
+ "collie": 231,
1349
+ "colobus, colobus monkey": 375,
1350
+ "combination lock": 507,
1351
+ "comic book": 917,
1352
+ "common iguana, iguana, Iguana iguana": 39,
1353
+ "common newt, Triturus vulgaris": 26,
1354
+ "computer keyboard, keypad": 508,
1355
+ "conch": 112,
1356
+ "confectionery, confectionary, candy store": 509,
1357
+ "consomme": 925,
1358
+ "container ship, containership, container vessel": 510,
1359
+ "convertible": 511,
1360
+ "coral fungus": 991,
1361
+ "coral reef": 973,
1362
+ "corkscrew, bottle screw": 512,
1363
+ "corn": 987,
1364
+ "cornet, horn, trumpet, trump": 513,
1365
+ "coucal": 91,
1366
+ "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286,
1367
+ "cowboy boot": 514,
1368
+ "cowboy hat, ten-gallon hat": 515,
1369
+ "coyote, prairie wolf, brush wolf, Canis latrans": 272,
1370
+ "cradle": 516,
1371
+ "crane": 517,
1372
+ "crash helmet": 518,
1373
+ "crate": 519,
1374
+ "crayfish, crawfish, crawdad, crawdaddy": 124,
1375
+ "crib, cot": 520,
1376
+ "cricket": 312,
1377
+ "croquet ball": 522,
1378
+ "crossword puzzle, crossword": 918,
1379
+ "crutch": 523,
1380
+ "cucumber, cuke": 943,
1381
+ "cuirass": 524,
1382
+ "cup": 968,
1383
+ "curly-coated retriever": 206,
1384
+ "custard apple": 956,
1385
+ "daisy": 985,
1386
+ "dalmatian, coach dog, carriage dog": 251,
1387
+ "dam, dike, dyke": 525,
1388
+ "damselfly": 320,
1389
+ "desk": 526,
1390
+ "desktop computer": 527,
1391
+ "dhole, Cuon alpinus": 274,
1392
+ "dial telephone, dial phone": 528,
1393
+ "diamondback, diamondback rattlesnake, Crotalus adamanteus": 67,
1394
+ "diaper, nappy, napkin": 529,
1395
+ "digital clock": 530,
1396
+ "digital watch": 531,
1397
+ "dingo, warrigal, warragal, Canis dingo": 273,
1398
+ "dining table, board": 532,
1399
+ "dishrag, dishcloth": 533,
1400
+ "dishwasher, dish washer, dishwashing machine": 534,
1401
+ "disk brake, disc brake": 535,
1402
+ "dock, dockage, docking facility": 536,
1403
+ "dogsled, dog sled, dog sleigh": 537,
1404
+ "dome": 538,
1405
+ "doormat, welcome mat": 539,
1406
+ "dough": 961,
1407
+ "dowitcher": 142,
1408
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319,
1409
+ "drake": 97,
1410
+ "drilling platform, offshore rig": 540,
1411
+ "drum, membranophone, tympan": 541,
1412
+ "drumstick": 542,
1413
+ "dugong, Dugong dugon": 149,
1414
+ "dumbbell": 543,
1415
+ "dung beetle": 305,
1416
+ "ear, spike, capitulum": 998,
1417
+ "earthstar": 995,
1418
+ "echidna, spiny anteater, anteater": 102,
1419
+ "eel": 390,
1420
+ "eft": 27,
1421
+ "eggnog": 969,
1422
+ "electric fan, blower": 545,
1423
+ "electric guitar": 546,
1424
+ "electric locomotive": 547,
1425
+ "electric ray, crampfish, numbfish, torpedo": 5,
1426
+ "entertainment center": 548,
1427
+ "envelope": 549,
1428
+ "espresso": 967,
1429
+ "espresso maker": 550,
1430
+ "face powder": 551,
1431
+ "feather boa, boa": 552,
1432
+ "fiddler crab": 120,
1433
+ "fig": 952,
1434
+ "file, file cabinet, filing cabinet": 553,
1435
+ "fire engine, fire truck": 555,
1436
+ "fire screen, fireguard": 556,
1437
+ "fireboat": 554,
1438
+ "flagpole, flagstaff": 557,
1439
+ "flamingo": 130,
1440
+ "flat-coated retriever": 205,
1441
+ "flatworm, platyhelminth": 110,
1442
+ "flute, transverse flute": 558,
1443
+ "fly": 308,
1444
+ "folding chair": 559,
1445
+ "football helmet": 560,
1446
+ "forklift": 561,
1447
+ "fountain": 562,
1448
+ "fountain pen": 563,
1449
+ "four-poster": 564,
1450
+ "fox squirrel, eastern fox squirrel, Sciurus niger": 335,
1451
+ "freight car": 565,
1452
+ "frilled lizard, Chlamydosaurus kingi": 43,
1453
+ "frying pan, frypan, skillet": 567,
1454
+ "fur coat": 568,
1455
+ "gar, garfish, garpike, billfish, Lepisosteus osseus": 395,
1456
+ "garbage truck, dustcart": 569,
1457
+ "garden spider, Aranea diademata": 74,
1458
+ "garter snake, grass snake": 57,
1459
+ "gas pump, gasoline pump, petrol pump, island dispenser": 571,
1460
+ "gasmask, respirator, gas helmet": 570,
1461
+ "gazelle": 353,
1462
+ "geyser": 974,
1463
+ "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388,
1464
+ "giant schnauzer": 197,
1465
+ "gibbon, Hylobates lar": 368,
1466
+ "go-kart": 573,
1467
+ "goblet": 572,
1468
+ "golden retriever": 207,
1469
+ "goldfinch, Carduelis carduelis": 11,
1470
+ "goldfish, Carassius auratus": 1,
1471
+ "golf ball": 574,
1472
+ "golfcart, golf cart": 575,
1473
+ "gondola": 576,
1474
+ "gong, tam-tam": 577,
1475
+ "goose": 99,
1476
+ "gorilla, Gorilla gorilla": 366,
1477
+ "gown": 578,
1478
+ "grand piano, grand": 579,
1479
+ "grasshopper, hopper": 311,
1480
+ "great grey owl, great gray owl, Strix nebulosa": 24,
1481
+ "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2,
1482
+ "green lizard, Lacerta viridis": 46,
1483
+ "green mamba": 64,
1484
+ "green snake, grass snake": 55,
1485
+ "greenhouse, nursery, glasshouse": 580,
1486
+ "grey fox, gray fox, Urocyon cinereoargenteus": 280,
1487
+ "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147,
1488
+ "grille, radiator grille": 581,
1489
+ "grocery store, grocery, food market, market": 582,
1490
+ "groenendael": 224,
1491
+ "groom, bridegroom": 982,
1492
+ "ground beetle, carabid beetle": 302,
1493
+ "guacamole": 924,
1494
+ "guenon, guenon monkey": 370,
1495
+ "guillotine": 583,
1496
+ "guinea pig, Cavia cobaya": 338,
1497
+ "gyromitra": 993,
1498
+ "hair slide": 584,
1499
+ "hair spray": 585,
1500
+ "half track": 586,
1501
+ "hammer": 587,
1502
+ "hammerhead, hammerhead shark": 4,
1503
+ "hamper": 588,
1504
+ "hamster": 333,
1505
+ "hand blower, blow dryer, blow drier, hair dryer, hair drier": 589,
1506
+ "hand-held computer, hand-held microcomputer": 590,
1507
+ "handkerchief, hankie, hanky, hankey": 591,
1508
+ "hard disc, hard disk, fixed disk": 592,
1509
+ "hare": 331,
1510
+ "harmonica, mouth organ, harp, mouth harp": 593,
1511
+ "harp": 594,
1512
+ "hartebeest": 351,
1513
+ "harvester, reaper": 595,
1514
+ "harvestman, daddy longlegs, Phalangium opilio": 70,
1515
+ "hatchet": 596,
1516
+ "hay": 958,
1517
+ "head cabbage": 936,
1518
+ "hen": 8,
1519
+ "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996,
1520
+ "hermit crab": 125,
1521
+ "hip, rose hip, rosehip": 989,
1522
+ "hippopotamus, hippo, river horse, Hippopotamus amphibius": 344,
1523
+ "hog, pig, grunter, squealer, Sus scrofa": 341,
1524
+ "hognose snake, puff adder, sand viper": 54,
1525
+ "holster": 597,
1526
+ "home theater, home theatre": 598,
1527
+ "honeycomb": 599,
1528
+ "hook, claw": 600,
1529
+ "hoopskirt, crinoline": 601,
1530
+ "horizontal bar, high bar": 602,
1531
+ "hornbill": 93,
1532
+ "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66,
1533
+ "horse cart, horse-cart": 603,
1534
+ "hot pot, hotpot": 926,
1535
+ "hotdog, hot dog, red hot": 934,
1536
+ "hourglass": 604,
1537
+ "house finch, linnet, Carpodacus mexicanus": 12,
1538
+ "howler monkey, howler": 379,
1539
+ "hummingbird": 94,
1540
+ "hyena, hyaena": 276,
1541
+ "iPod": 605,
1542
+ "ibex, Capra ibex": 350,
1543
+ "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296,
1544
+ "ice cream, icecream": 928,
1545
+ "ice lolly, lolly, lollipop, popsicle": 929,
1546
+ "impala, Aepyceros melampus": 352,
1547
+ "indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14,
1548
+ "indri, indris, Indri indri, Indri brevicaudatus": 384,
1549
+ "iron, smoothing iron": 606,
1550
+ "isopod": 126,
1551
+ "jacamar": 95,
1552
+ "jack-o'-lantern": 607,
1553
+ "jackfruit, jak, jack": 955,
1554
+ "jaguar, panther, Panthera onca, Felis onca": 290,
1555
+ "jay": 17,
1556
+ "jean, blue jean, denim": 608,
1557
+ "jeep, landrover": 609,
1558
+ "jellyfish": 107,
1559
+ "jersey, T-shirt, tee shirt": 610,
1560
+ "jigsaw puzzle": 611,
1561
+ "jinrikisha, ricksha, rickshaw": 612,
1562
+ "joystick": 613,
1563
+ "junco, snowbird": 13,
1564
+ "keeshond": 261,
1565
+ "kelpie": 227,
1566
+ "killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148,
1567
+ "kimono": 614,
1568
+ "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121,
1569
+ "king penguin, Aptenodytes patagonica": 145,
1570
+ "king snake, kingsnake": 56,
1571
+ "kit fox, Vulpes macrotis": 278,
1572
+ "kite": 21,
1573
+ "knee pad": 615,
1574
+ "knot": 616,
1575
+ "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105,
1576
+ "komondor": 228,
1577
+ "kuvasz": 222,
1578
+ "lab coat, laboratory coat": 617,
1579
+ "lacewing, lacewing fly": 318,
1580
+ "ladle": 618,
1581
+ "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301,
1582
+ "lakeside, lakeshore": 975,
1583
+ "lampshade, lamp shade": 619,
1584
+ "langur": 374,
1585
+ "laptop, laptop computer": 620,
1586
+ "lawn mower, mower": 621,
1587
+ "leaf beetle, chrysomelid": 304,
1588
+ "leafhopper": 317,
1589
+ "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34,
1590
+ "lemon": 951,
1591
+ "lens cap, lens cover": 622,
1592
+ "leopard, Panthera pardus": 288,
1593
+ "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387,
1594
+ "letter opener, paper knife, paperknife": 623,
1595
+ "library": 624,
1596
+ "lifeboat": 625,
1597
+ "lighter, light, igniter, ignitor": 626,
1598
+ "limousine, limo": 627,
1599
+ "limpkin, Aramus pictus": 135,
1600
+ "liner, ocean liner": 628,
1601
+ "lion, king of beasts, Panthera leo": 291,
1602
+ "lionfish": 396,
1603
+ "lipstick, lip rouge": 629,
1604
+ "little blue heron, Egretta caerulea": 131,
1605
+ "llama": 355,
1606
+ "loggerhead, loggerhead turtle, Caretta caretta": 33,
1607
+ "long-horned beetle, longicorn, longicorn beetle": 303,
1608
+ "lorikeet": 90,
1609
+ "lotion": 631,
1610
+ "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632,
1611
+ "loupe, jeweler's loupe": 633,
1612
+ "lumbermill, sawmill": 634,
1613
+ "lycaenid, lycaenid butterfly": 326,
1614
+ "lynx, catamount": 287,
1615
+ "macaque": 373,
1616
+ "macaw": 88,
1617
+ "magnetic compass": 635,
1618
+ "magpie": 18,
1619
+ "mailbag, postbag": 636,
1620
+ "mailbox, letter box": 637,
1621
+ "maillot": 638,
1622
+ "maillot, tank suit": 639,
1623
+ "malamute, malemute, Alaskan malamute": 249,
1624
+ "malinois": 225,
1625
+ "manhole cover": 640,
1626
+ "mantis, mantid": 315,
1627
+ "maraca": 641,
1628
+ "marimba, xylophone": 642,
1629
+ "marmoset": 377,
1630
+ "marmot": 336,
1631
+ "mashed potato": 935,
1632
+ "mask": 643,
1633
+ "matchstick": 644,
1634
+ "maypole": 645,
1635
+ "maze, labyrinth": 646,
1636
+ "measuring cup": 647,
1637
+ "meat loaf, meatloaf": 962,
1638
+ "medicine chest, medicine cabinet": 648,
1639
+ "meerkat, mierkat": 299,
1640
+ "megalith, megalithic structure": 649,
1641
+ "menu": 922,
1642
+ "microphone, mike": 650,
1643
+ "microwave, microwave oven": 651,
1644
+ "military uniform": 652,
1645
+ "milk can": 653,
1646
+ "miniature pinscher": 237,
1647
+ "miniature poodle": 266,
1648
+ "miniature schnauzer": 196,
1649
+ "minibus": 654,
1650
+ "miniskirt, mini": 655,
1651
+ "minivan": 656,
1652
+ "mink": 357,
1653
+ "missile": 657,
1654
+ "mitten": 658,
1655
+ "mixing bowl": 659,
1656
+ "mobile home, manufactured home": 660,
1657
+ "modem": 662,
1658
+ "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323,
1659
+ "monastery": 663,
1660
+ "mongoose": 298,
1661
+ "monitor": 664,
1662
+ "moped": 665,
1663
+ "mortar": 666,
1664
+ "mortarboard": 667,
1665
+ "mosque": 668,
1666
+ "mosquito net": 669,
1667
+ "motor scooter, scooter": 670,
1668
+ "mountain bike, all-terrain bike, off-roader": 671,
1669
+ "mountain tent": 672,
1670
+ "mouse, computer mouse": 673,
1671
+ "mousetrap": 674,
1672
+ "moving van": 675,
1673
+ "mud turtle": 35,
1674
+ "mushroom": 947,
1675
+ "muzzle": 676,
1676
+ "nail": 677,
1677
+ "neck brace": 678,
1678
+ "necklace": 679,
1679
+ "nematode, nematode worm, roundworm": 111,
1680
+ "night snake, Hypsiglena torquata": 60,
1681
+ "nipple": 680,
1682
+ "notebook, notebook computer": 681,
1683
+ "obelisk": 682,
1684
+ "oboe, hautboy, hautbois": 683,
1685
+ "ocarina, sweet potato": 684,
1686
+ "odometer, hodometer, mileometer, milometer": 685,
1687
+ "oil filter": 686,
1688
+ "orange": 950,
1689
+ "orangutan, orang, orangutang, Pongo pygmaeus": 365,
1690
+ "organ, pipe organ": 687,
1691
+ "oscilloscope, scope, cathode-ray oscilloscope, CRO": 688,
1692
+ "ostrich, Struthio camelus": 9,
1693
+ "otter": 360,
1694
+ "otterhound, otter hound": 175,
1695
+ "overskirt": 689,
1696
+ "ox": 345,
1697
+ "oxcart": 690,
1698
+ "oxygen mask": 691,
1699
+ "oystercatcher, oyster catcher": 143,
1700
+ "packet": 692,
1701
+ "paddle, boat paddle": 693,
1702
+ "paddlewheel, paddle wheel": 694,
1703
+ "padlock": 695,
1704
+ "paintbrush": 696,
1705
+ "pajama, pyjama, pj's, jammies": 697,
1706
+ "palace": 698,
1707
+ "panpipe, pandean pipe, syrinx": 699,
1708
+ "paper towel": 700,
1709
+ "papillon": 157,
1710
+ "parachute, chute": 701,
1711
+ "parallel bars, bars": 702,
1712
+ "park bench": 703,
1713
+ "parking meter": 704,
1714
+ "partridge": 86,
1715
+ "passenger car, coach, carriage": 705,
1716
+ "patas, hussar monkey, Erythrocebus patas": 371,
1717
+ "patio, terrace": 706,
1718
+ "pay-phone, pay-station": 707,
1719
+ "peacock": 84,
1720
+ "pedestal, plinth, footstall": 708,
1721
+ "pelican": 144,
1722
+ "pencil box, pencil case": 709,
1723
+ "pencil sharpener": 710,
1724
+ "perfume, essence": 711,
1725
+ "photocopier": 713,
1726
+ "pick, plectrum, plectron": 714,
1727
+ "pickelhaube": 715,
1728
+ "picket fence, paling": 716,
1729
+ "pickup, pickup truck": 717,
1730
+ "pier": 718,
1731
+ "piggy bank, penny bank": 719,
1732
+ "pill bottle": 720,
1733
+ "pillow": 721,
1734
+ "pineapple, ananas": 953,
1735
+ "ping-pong ball": 722,
1736
+ "pinwheel": 723,
1737
+ "pirate, pirate ship": 724,
1738
+ "pitcher, ewer": 725,
1739
+ "pizza, pizza pie": 963,
1740
+ "plane, carpenter's plane, woodworking plane": 726,
1741
+ "planetarium": 727,
1742
+ "plastic bag": 728,
1743
+ "plate": 923,
1744
+ "plate rack": 729,
1745
+ "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103,
1746
+ "plow, plough": 730,
1747
+ "plunger, plumber's helper": 731,
1748
+ "pole": 733,
1749
+ "polecat, fitch, foulmart, foumart, Mustela putorius": 358,
1750
+ "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734,
1751
+ "pomegranate": 957,
1752
+ "poncho": 735,
1753
+ "pool table, billiard table, snooker table": 736,
1754
+ "pop bottle, soda bottle": 737,
1755
+ "porcupine, hedgehog": 334,
1756
+ "pot, flowerpot": 738,
1757
+ "potpie": 964,
1758
+ "potter's wheel": 739,
1759
+ "power drill": 740,
1760
+ "prairie chicken, prairie grouse, prairie fowl": 83,
1761
+ "prayer rug, prayer mat": 741,
1762
+ "pretzel": 932,
1763
+ "printer": 742,
1764
+ "prison, prison house": 743,
1765
+ "proboscis monkey, Nasalis larvatus": 376,
1766
+ "projectile, missile": 744,
1767
+ "projector": 745,
1768
+ "promontory, headland, head, foreland": 976,
1769
+ "ptarmigan": 81,
1770
+ "puck, hockey puck": 746,
1771
+ "puffer, pufferfish, blowfish, globefish": 397,
1772
+ "pug, pug-dog": 254,
1773
+ "punching bag, punch bag, punching ball, punchball": 747,
1774
+ "purse": 748,
1775
+ "quail": 85,
1776
+ "quill, quill pen": 749,
1777
+ "quilt, comforter, comfort, puff": 750,
1778
+ "racer, race car, racing car": 751,
1779
+ "racket, racquet": 752,
1780
+ "radiator": 753,
1781
+ "radio telescope, radio reflector": 755,
1782
+ "radio, wireless": 754,
1783
+ "rain barrel": 756,
1784
+ "ram, tup": 348,
1785
+ "rapeseed": 984,
1786
+ "recreational vehicle, RV, R.V.": 757,
1787
+ "red fox, Vulpes vulpes": 277,
1788
+ "red wine": 966,
1789
+ "red wolf, maned wolf, Canis rufus, Canis niger": 271,
1790
+ "red-backed sandpiper, dunlin, Erolia alpina": 140,
1791
+ "red-breasted merganser, Mergus serrator": 98,
1792
+ "redbone": 168,
1793
+ "redshank, Tringa totanus": 141,
1794
+ "reel": 758,
1795
+ "reflex camera": 759,
1796
+ "refrigerator, icebox": 760,
1797
+ "remote control, remote": 761,
1798
+ "restaurant, eating house, eating place, eatery": 762,
1799
+ "revolver, six-gun, six-shooter": 763,
1800
+ "rhinoceros beetle": 306,
1801
+ "rifle": 764,
1802
+ "ringlet, ringlet butterfly": 322,
1803
+ "ringneck snake, ring-necked snake, ring snake": 53,
1804
+ "robin, American robin, Turdus migratorius": 15,
1805
+ "rock beauty, Holocanthus tricolor": 392,
1806
+ "rock crab, Cancer irroratus": 119,
1807
+ "rock python, rock snake, Python sebae": 62,
1808
+ "rocking chair, rocker": 765,
1809
+ "rotisserie": 766,
1810
+ "rubber eraser, rubber, pencil eraser": 767,
1811
+ "ruddy turnstone, Arenaria interpres": 139,
1812
+ "ruffed grouse, partridge, Bonasa umbellus": 82,
1813
+ "rugby ball": 768,
1814
+ "rule, ruler": 769,
1815
+ "running shoe": 770,
1816
+ "safe": 771,
1817
+ "safety pin": 772,
1818
+ "saltshaker, salt shaker": 773,
1819
+ "sandal": 774,
1820
+ "sandbar, sand bar": 977,
1821
+ "sarong": 775,
1822
+ "sax, saxophone": 776,
1823
+ "scabbard": 777,
1824
+ "scale, weighing machine": 778,
1825
+ "schipperke": 223,
1826
+ "school bus": 779,
1827
+ "schooner": 780,
1828
+ "scoreboard": 781,
1829
+ "scorpion": 71,
1830
+ "screen, CRT screen": 782,
1831
+ "screw": 783,
1832
+ "screwdriver": 784,
1833
+ "scuba diver": 983,
1834
+ "sea anemone, anemone": 108,
1835
+ "sea cucumber, holothurian": 329,
1836
+ "sea lion": 150,
1837
+ "sea slug, nudibranch": 115,
1838
+ "sea snake": 65,
1839
+ "sea urchin": 328,
1840
+ "seashore, coast, seacoast, sea-coast": 978,
1841
+ "seat belt, seatbelt": 785,
1842
+ "sewing machine": 786,
1843
+ "shield, buckler": 787,
1844
+ "shoe shop, shoe-shop, shoe store": 788,
1845
+ "shoji": 789,
1846
+ "shopping basket": 790,
1847
+ "shopping cart": 791,
1848
+ "shovel": 792,
1849
+ "shower cap": 793,
1850
+ "shower curtain": 794,
1851
+ "siamang, Hylobates syndactylus, Symphalangus syndactylus": 369,
1852
+ "sidewinder, horned rattlesnake, Crotalus cerastes": 68,
1853
+ "silky terrier, Sydney silky": 201,
1854
+ "ski": 795,
1855
+ "ski mask": 796,
1856
+ "skunk, polecat, wood pussy": 361,
1857
+ "sleeping bag": 797,
1858
+ "slide rule, slipstick": 798,
1859
+ "sliding door": 799,
1860
+ "slot, one-armed bandit": 800,
1861
+ "sloth bear, Melursus ursinus, Ursus ursinus": 297,
1862
+ "slug": 114,
1863
+ "snail": 113,
1864
+ "snorkel": 801,
1865
+ "snow leopard, ounce, Panthera uncia": 289,
1866
+ "snowmobile": 802,
1867
+ "snowplow, snowplough": 803,
1868
+ "soap dispenser": 804,
1869
+ "soccer ball": 805,
1870
+ "sock": 806,
1871
+ "soft-coated wheaten terrier": 202,
1872
+ "solar dish, solar collector, solar furnace": 807,
1873
+ "sombrero": 808,
1874
+ "sorrel": 339,
1875
+ "soup bowl": 809,
1876
+ "space bar": 810,
1877
+ "space heater": 811,
1878
+ "space shuttle": 812,
1879
+ "spaghetti squash": 940,
1880
+ "spatula": 813,
1881
+ "speedboat": 814,
1882
+ "spider monkey, Ateles geoffroyi": 381,
1883
+ "spider web, spider's web": 815,
1884
+ "spindle": 816,
1885
+ "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123,
1886
+ "spoonbill": 129,
1887
+ "sports car, sport car": 817,
1888
+ "spotlight, spot": 818,
1889
+ "spotted salamander, Ambystoma maculatum": 28,
1890
+ "squirrel monkey, Saimiri sciureus": 382,
1891
+ "stage": 819,
1892
+ "standard poodle": 267,
1893
+ "standard schnauzer": 198,
1894
+ "starfish, sea star": 327,
1895
+ "steam locomotive": 820,
1896
+ "steel arch bridge": 821,
1897
+ "steel drum": 822,
1898
+ "stethoscope": 823,
1899
+ "stingray": 6,
1900
+ "stinkhorn, carrion fungus": 994,
1901
+ "stole": 824,
1902
+ "stone wall": 825,
1903
+ "stopwatch, stop watch": 826,
1904
+ "stove": 827,
1905
+ "strainer": 828,
1906
+ "strawberry": 949,
1907
+ "street sign": 919,
1908
+ "streetcar, tram, tramcar, trolley, trolley car": 829,
1909
+ "stretcher": 830,
1910
+ "studio couch, day bed": 831,
1911
+ "stupa, tope": 832,
1912
+ "sturgeon": 394,
1913
+ "submarine, pigboat, sub, U-boat": 833,
1914
+ "suit, suit of clothes": 834,
1915
+ "sulphur butterfly, sulfur butterfly": 325,
1916
+ "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89,
1917
+ "sundial": 835,
1918
+ "sunglass": 836,
1919
+ "sunglasses, dark glasses, shades": 837,
1920
+ "sunscreen, sunblock, sun blocker": 838,
1921
+ "suspension bridge": 839,
1922
+ "swab, swob, mop": 840,
1923
+ "sweatshirt": 841,
1924
+ "swimming trunks, bathing trunks": 842,
1925
+ "swing": 843,
1926
+ "switch, electric switch, electrical switch": 844,
1927
+ "syringe": 845,
1928
+ "tabby, tabby cat": 281,
1929
+ "table lamp": 846,
1930
+ "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32,
1931
+ "tank, army tank, armored combat vehicle, armoured combat vehicle": 847,
1932
+ "tape player": 848,
1933
+ "tarantula": 76,
1934
+ "teapot": 849,
1935
+ "teddy, teddy bear": 850,
1936
+ "television, television system": 851,
1937
+ "tench, Tinca tinca": 0,
1938
+ "tennis ball": 852,
1939
+ "terrapin": 36,
1940
+ "thatch, thatched roof": 853,
1941
+ "theater curtain, theatre curtain": 854,
1942
+ "thimble": 855,
1943
+ "three-toed sloth, ai, Bradypus tridactylus": 364,
1944
+ "thresher, thrasher, threshing machine": 856,
1945
+ "throne": 857,
1946
+ "thunder snake, worm snake, Carphophis amoenus": 52,
1947
+ "tick": 78,
1948
+ "tiger beetle": 300,
1949
+ "tiger cat": 282,
1950
+ "tiger shark, Galeocerdo cuvieri": 3,
1951
+ "tiger, Panthera tigris": 292,
1952
+ "tile roof": 858,
1953
+ "timber wolf, grey wolf, gray wolf, Canis lupus": 269,
1954
+ "titi, titi monkey": 380,
1955
+ "toaster": 859,
1956
+ "tobacco shop, tobacconist shop, tobacconist": 860,
1957
+ "toilet seat": 861,
1958
+ "toilet tissue, toilet paper, bathroom tissue": 999,
1959
+ "torch": 862,
1960
+ "totem pole": 863,
1961
+ "toucan": 96,
1962
+ "tow truck, tow car, wrecker": 864,
1963
+ "toy poodle": 265,
1964
+ "toy terrier": 158,
1965
+ "toyshop": 865,
1966
+ "tractor": 866,
1967
+ "traffic light, traffic signal, stoplight": 920,
1968
+ "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867,
1969
+ "tray": 868,
1970
+ "tree frog, tree-frog": 31,
1971
+ "trench coat": 869,
1972
+ "triceratops": 51,
1973
+ "tricycle, trike, velocipede": 870,
1974
+ "trifle": 927,
1975
+ "trilobite": 69,
1976
+ "trimaran": 871,
1977
+ "tripod": 872,
1978
+ "triumphal arch": 873,
1979
+ "trolleybus, trolley coach, trackless trolley": 874,
1980
+ "trombone": 875,
1981
+ "tub, vat": 876,
1982
+ "turnstile": 877,
1983
+ "tusker": 101,
1984
+ "typewriter keyboard": 878,
1985
+ "umbrella": 879,
1986
+ "unicycle, monocycle": 880,
1987
+ "upright, upright piano": 881,
1988
+ "vacuum, vacuum cleaner": 882,
1989
+ "valley, vale": 979,
1990
+ "vase": 883,
1991
+ "vault": 884,
1992
+ "velvet": 885,
1993
+ "vending machine": 886,
1994
+ "vestment": 887,
1995
+ "viaduct": 888,
1996
+ "vine snake": 59,
1997
+ "violin, fiddle": 889,
1998
+ "vizsla, Hungarian pointer": 211,
1999
+ "volcano": 980,
2000
+ "volleyball": 890,
2001
+ "vulture": 23,
2002
+ "waffle iron": 891,
2003
+ "walking stick, walkingstick, stick insect": 313,
2004
+ "wall clock": 892,
2005
+ "wallaby, brush kangaroo": 104,
2006
+ "wallet, billfold, notecase, pocketbook": 893,
2007
+ "wardrobe, closet, press": 894,
2008
+ "warplane, military plane": 895,
2009
+ "warthog": 343,
2010
+ "washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896,
2011
+ "washer, automatic washer, washing machine": 897,
2012
+ "water bottle": 898,
2013
+ "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346,
2014
+ "water jug": 899,
2015
+ "water ouzel, dipper": 20,
2016
+ "water snake": 58,
2017
+ "water tower": 900,
2018
+ "weasel": 356,
2019
+ "web site, website, internet site, site": 916,
2020
+ "weevil": 307,
2021
+ "whippet": 172,
2022
+ "whiptail, whiptail lizard": 41,
2023
+ "whiskey jug": 901,
2024
+ "whistle": 902,
2025
+ "white stork, Ciconia ciconia": 127,
2026
+ "white wolf, Arctic wolf, Canis lupus tundrarum": 270,
2027
+ "wig": 903,
2028
+ "wild boar, boar, Sus scrofa": 342,
2029
+ "window screen": 904,
2030
+ "window shade": 905,
2031
+ "wine bottle": 907,
2032
+ "wing": 908,
2033
+ "wire-haired fox terrier": 188,
2034
+ "wok": 909,
2035
+ "wolf spider, hunting spider": 77,
2036
+ "wombat": 106,
2037
+ "wood rabbit, cottontail, cottontail rabbit": 330,
2038
+ "wooden spoon": 910,
2039
+ "wool, woolen, woollen": 911,
2040
+ "worm fence, snake fence, snake-rail fence, Virginia fence": 912,
2041
+ "wreck": 913,
2042
+ "yawl": 914,
2043
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986,
2044
+ "yurt": 915,
2045
+ "zebra": 340,
2046
+ "zucchini, courgette": 939
2047
+ },
2048
+ "layer_norm_eps": 1e-05,
2049
+ "length_penalty": 1.0,
2050
+ "max_length": 20,
2051
+ "min_length": 0,
2052
+ "mlp_ratio": 4.0,
2053
+ "model_type": "swin",
2054
+ "no_repeat_ngram_size": 0,
2055
+ "num_beam_groups": 1,
2056
+ "num_beams": 1,
2057
+ "num_channels": 3,
2058
+ "num_heads": [
2059
+ 3,
2060
+ 6,
2061
+ 12,
2062
+ 24
2063
+ ],
2064
+ "num_layers": 4,
2065
+ "num_return_sequences": 1,
2066
+ "out_features": [
2067
+ "stage1",
2068
+ "stage2",
2069
+ "stage3",
2070
+ "stage4"
2071
+ ],
2072
+ "output_attentions": false,
2073
+ "output_hidden_states": false,
2074
+ "output_scores": false,
2075
+ "pad_token_id": null,
2076
+ "patch_size": 4,
2077
+ "path_norm": true,
2078
+ "prefix": null,
2079
+ "problem_type": null,
2080
+ "pruned_heads": {},
2081
+ "qkv_bias": true,
2082
+ "remove_invalid_values": false,
2083
+ "repetition_penalty": 1.0,
2084
+ "return_dict": true,
2085
+ "return_dict_in_generate": false,
2086
+ "sep_token_id": null,
2087
+ "stage_names": [
2088
+ "stem",
2089
+ "stage1",
2090
+ "stage2",
2091
+ "stage3",
2092
+ "stage4"
2093
+ ],
2094
+ "suppress_tokens": null,
2095
+ "task_specific_params": null,
2096
+ "temperature": 1.0,
2097
+ "tf_legacy_loss": false,
2098
+ "tie_encoder_decoder": false,
2099
+ "tie_word_embeddings": true,
2100
+ "tokenizer_class": null,
2101
+ "top_k": 50,
2102
+ "top_p": 1.0,
2103
+ "torch_dtype": "float32",
2104
+ "torchscript": false,
2105
+ "transformers_version": "4.26.0.dev0",
2106
+ "typical_p": 1.0,
2107
+ "use_absolute_embeddings": false,
2108
+ "use_bfloat16": false,
2109
+ "window_size": 7
2110
+ },
2111
+ "class_weight": 2.0,
2112
+ "common_stride": 4,
2113
+ "decoder_layers": 10,
2114
+ "dice_weight": 5.0,
2115
+ "dim_feedforward": 2048,
2116
+ "dropout": 0.0,
2117
+ "encoder_feedforward_dim": 1024,
2118
+ "encoder_layers": 6,
2119
+ "enforce_input_proj": false,
2120
+ "enforce_input_projection": false,
2121
+ "feature_size": 256,
2122
+ "feature_strides": [
2123
+ 4,
2124
+ 8,
2125
+ 16,
2126
+ 32
2127
+ ],
2128
+ "hidden_dim": 256,
2129
+ "id2label": {
2130
+ "0": "person",
2131
+ "1": "bicycle",
2132
+ "2": "car",
2133
+ "3": "motorbike",
2134
+ "4": "aeroplane",
2135
+ "5": "bus",
2136
+ "6": "train",
2137
+ "7": "truck",
2138
+ "8": "boat",
2139
+ "9": "traffic light",
2140
+ "10": "fire hydrant",
2141
+ "11": "stop sign",
2142
+ "12": "parking meter",
2143
+ "13": "bench",
2144
+ "14": "bird",
2145
+ "15": "cat",
2146
+ "16": "dog",
2147
+ "17": "horse",
2148
+ "18": "sheep",
2149
+ "19": "cow",
2150
+ "20": "elephant",
2151
+ "21": "bear",
2152
+ "22": "zebra",
2153
+ "23": "giraffe",
2154
+ "24": "backpack",
2155
+ "25": "umbrella",
2156
+ "26": "handbag",
2157
+ "27": "tie",
2158
+ "28": "suitcase",
2159
+ "29": "frisbee",
2160
+ "30": "skis",
2161
+ "31": "snowboard",
2162
+ "32": "sports ball",
2163
+ "33": "kite",
2164
+ "34": "baseball bat",
2165
+ "35": "baseball glove",
2166
+ "36": "skateboard",
2167
+ "37": "surfboard",
2168
+ "38": "tennis racket",
2169
+ "39": "bottle",
2170
+ "40": "wine glass",
2171
+ "41": "cup",
2172
+ "42": "fork",
2173
+ "43": "knife",
2174
+ "44": "spoon",
2175
+ "45": "bowl",
2176
+ "46": "banana",
2177
+ "47": "apple",
2178
+ "48": "sandwich",
2179
+ "49": "orange",
2180
+ "50": "broccoli",
2181
+ "51": "carrot",
2182
+ "52": "hot dog",
2183
+ "53": "pizza",
2184
+ "54": "donut",
2185
+ "55": "cake",
2186
+ "56": "chair",
2187
+ "57": "sofa",
2188
+ "58": "pottedplant",
2189
+ "59": "bed",
2190
+ "60": "diningtable",
2191
+ "61": "toilet",
2192
+ "62": "tvmonitor",
2193
+ "63": "laptop",
2194
+ "64": "mouse",
2195
+ "65": "remote",
2196
+ "66": "keyboard",
2197
+ "67": "cell phone",
2198
+ "68": "microwave",
2199
+ "69": "oven",
2200
+ "70": "toaster",
2201
+ "71": "sink",
2202
+ "72": "refrigerator",
2203
+ "73": "book",
2204
+ "74": "clock",
2205
+ "75": "vase",
2206
+ "76": "scissors",
2207
+ "77": "teddy bear",
2208
+ "78": "hair drier",
2209
+ "79": "toothbrush"
2210
+ },
2211
+ "ignore_value": 255,
2212
+ "importance_sample_ratio": 0.75,
2213
+ "init_std": 0.02,
2214
+ "init_xavier_std": 1.0,
2215
+ "label2id": {
2216
+ "aeroplane": 4,
2217
+ "apple": 47,
2218
+ "backpack": 24,
2219
+ "banana": 46,
2220
+ "baseball bat": 34,
2221
+ "baseball glove": 35,
2222
+ "bear": 21,
2223
+ "bed": 59,
2224
+ "bench": 13,
2225
+ "bicycle": 1,
2226
+ "bird": 14,
2227
+ "boat": 8,
2228
+ "book": 73,
2229
+ "bottle": 39,
2230
+ "bowl": 45,
2231
+ "broccoli": 50,
2232
+ "bus": 5,
2233
+ "cake": 55,
2234
+ "car": 2,
2235
+ "carrot": 51,
2236
+ "cat": 15,
2237
+ "cell phone": 67,
2238
+ "chair": 56,
2239
+ "clock": 74,
2240
+ "cow": 19,
2241
+ "cup": 41,
2242
+ "diningtable": 60,
2243
+ "dog": 16,
2244
+ "donut": 54,
2245
+ "elephant": 20,
2246
+ "fire hydrant": 10,
2247
+ "fork": 42,
2248
+ "frisbee": 29,
2249
+ "giraffe": 23,
2250
+ "hair drier": 78,
2251
+ "handbag": 26,
2252
+ "horse": 17,
2253
+ "hot dog": 52,
2254
+ "keyboard": 66,
2255
+ "kite": 33,
2256
+ "knife": 43,
2257
+ "laptop": 63,
2258
+ "microwave": 68,
2259
+ "motorbike": 3,
2260
+ "mouse": 64,
2261
+ "orange": 49,
2262
+ "oven": 69,
2263
+ "parking meter": 12,
2264
+ "person": 0,
2265
+ "pizza": 53,
2266
+ "pottedplant": 58,
2267
+ "refrigerator": 72,
2268
+ "remote": 65,
2269
+ "sandwich": 48,
2270
+ "scissors": 76,
2271
+ "sheep": 18,
2272
+ "sink": 71,
2273
+ "skateboard": 36,
2274
+ "skis": 30,
2275
+ "snowboard": 31,
2276
+ "sofa": 57,
2277
+ "spoon": 44,
2278
+ "sports ball": 32,
2279
+ "stop sign": 11,
2280
+ "suitcase": 28,
2281
+ "surfboard": 37,
2282
+ "teddy bear": 77,
2283
+ "tennis racket": 38,
2284
+ "tie": 27,
2285
+ "toaster": 70,
2286
+ "toilet": 61,
2287
+ "toothbrush": 79,
2288
+ "traffic light": 9,
2289
+ "train": 6,
2290
+ "truck": 7,
2291
+ "tvmonitor": 62,
2292
+ "umbrella": 25,
2293
+ "vase": 75,
2294
+ "wine glass": 40,
2295
+ "zebra": 22
2296
+ },
2297
+ "mask_feature_size": 256,
2298
+ "mask_weight": 5.0,
2299
+ "model_type": "mask2former",
2300
+ "no_object_weight": 0.1,
2301
+ "num_attention_heads": 8,
2302
+ "num_hidden_layers": 10,
2303
+ "num_queries": 100,
2304
+ "output_auxiliary_logits": null,
2305
+ "oversample_ratio": 3.0,
2306
+ "pre_norm": false,
2307
+ "torch_dtype": "float32",
2308
+ "train_num_points": 12544,
2309
+ "transformers_version": null,
2310
+ "use_auxiliary_loss": true
2311
+ }
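The remainder of the JSON above configures the Mask2Former head (Swin backbone settings, COCO label maps, decoder and loss weights). For orientation, here is a minimal sketch, not part of the upload, of how the file is consumed; it mirrors the construction in external/human_matting/stylematte.py, which keeps only the pixel-level module for matting:

    from transformers import Mask2FormerForUniversalSegmentation
    from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig

    # Rebuild the Mask2Former graph from this JSON, then keep the Swin backbone + pixel decoder.
    config = Mask2FormerConfig.from_json_file('./configs/stylematte_config.json')
    pixel_level_module = Mask2FormerForUniversalSegmentation(config).base_model.pixel_level_module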
configs/vhap_tracking/base_tracking_config.yaml ADDED
@@ -0,0 +1,64 @@
1
+ "# tyro YAML.\n!dataclass:BaseTrackingConfig\nasync_func: true\nbegin_frame_idx: 0\n\
2
+ begin_stage: null\ndata: !dataclass:DataConfig\n _target: vhap.data.video_dataset.VideoDataset\n\
3
+ \ align_cameras_to_axes: true\n background_color: white\n calibrated: false\n\
4
+ \ camera_convention_conversion: opencv->opengl\n division: null\n landmark_source:\
5
+ \ star\n n_downsample_rgb: null\n root_folder: ''\n scale_factor: 1.0\n sequence:\
6
+ \ ''\n subset: null\n target_extrinsic_type: w2c\n use_alpha_map: false\n use_landmark:\
7
+ \ true\ndevice: cuda\nexp: !dataclass:ExperimentConfig\n keyframes: !!python/tuple\
8
+ \ []\n output_folder: !!python/object/apply:pathlib.PosixPath\n - output\n -\
9
+ \ track\n photometric: false\n reuse_landmarks: true\nlog: !dataclass:LogConfig\n\
10
+ \ image_format: jpg\n interval_media: 500\n interval_scalar: 100\n max_num_views:\
11
+ \ 3\n stack_views_in_rows: true\n view_indices: !!python/tuple []\nlr: !dataclass:LearningRateConfig\n\
12
+ \ base: 0.005\n camera: 0.005\n dynamic_offset: 0.0005\n expr: 0.05\n light:\
13
+ \ 0.005\n static_offset: 0.0005\n translation: 0.001\nmodel: !dataclass:ModelConfig\n\
14
+ \ add_teeth: true\n flame_params_path: null\n n_expr: 100\n n_shape: 300\n \
15
+ \ n_tex: 100\n occluded: !!python/tuple\n - hair\n remove_lip_inside: false\n\
16
+ \ residual_tex: true\n tex_clusters: !!python/tuple\n - skin\n - hair\n - boundary\n\
17
+ \ - lips_tight\n - teeth\n - sclerae\n - irises\n tex_extra: true\n tex_painted:\
18
+ \ true\n tex_resolution: 2048\n use_dynamic_offset: false\n use_static_offset:\
19
+ \ false\npipeline: !dataclass:PipelineConfig\n lmk_global_tracking: !dataclass:StageLmkGlobalTrackingConfig\n\
20
+ \ disable_jawline_landmarks: false\n num_epochs: 0\n optimizable_params:\
21
+ \ &id001 !!python/tuple\n - cam\n - pose\n - shape\n - joints\n -\
22
+ \ expr\n lmk_init_all: !dataclass:StageLmkInitAllConfig\n disable_jawline_landmarks:\
23
+ \ false\n num_steps: 300\n optimizable_params: *id001\n lmk_init_rigid: !dataclass:StageLmkInitRigidConfig\n\
24
+ \ disable_jawline_landmarks: false\n num_steps: 300\n optimizable_params:\
25
+ \ !!python/tuple\n - cam\n - pose\n lmk_sequential_tracking: !dataclass:StageLmkSequentialTrackingConfig\n\
26
+ \ disable_jawline_landmarks: false\n num_steps: 50\n optimizable_params:\
27
+ \ !!python/tuple\n - pose\n - joints\n - expr\n rgb_global_tracking: !dataclass:StageRgbGlobalTrackingConfig\n\
28
+ \ align_boundary_except: !!python/tuple\n - bottomline\n - hair\n align_texture_except:\
29
+ \ !!python/tuple\n - hair\n disable_jawline_landmarks: true\n num_epochs:\
30
+ \ 30\n optimizable_params: !!python/tuple\n - cam\n - pose\n - shape\n\
31
+ \ - joints\n - expr\n - texture\n - lights\n - static_offset\n \
32
+ \ - dynamic_offset\n rgb_init_all: !dataclass:StageRgbInitAllConfig\n align_boundary_except:\
33
+ \ !!python/tuple\n - hair\n - bottomline\n - hair\n align_texture_except:\
34
+ \ !!python/tuple\n - hair\n - boundary\n - neck\n - hair\n disable_jawline_landmarks:\
35
+ \ true\n num_steps: 500\n optimizable_params: !!python/tuple\n - cam\n\
36
+ \ - pose\n - shape\n - joints\n - expr\n - texture\n - lights\n\
37
+ \ rgb_init_offset: !dataclass:StageRgbInitOffsetConfig\n align_boundary_except:\
38
+ \ !!python/tuple\n - bottomline\n - hair\n align_texture_except: !!python/tuple\n\
39
+ \ - hair\n - boundary\n - neck\n - hair\n disable_jawline_landmarks:\
40
+ \ true\n num_steps: 500\n optimizable_params: !!python/tuple\n - cam\n\
41
+ \ - pose\n - shape\n - joints\n - expr\n - texture\n - lights\n\
42
+ \ - static_offset\n rgb_init_texture: !dataclass:StageRgbInitTextureConfig\n\
43
+ \ align_boundary_except: !!python/tuple\n - hair\n - boundary\n - hair\n\
44
+ \ align_texture_except: !!python/tuple\n - hair\n - boundary\n - neck\n\
45
+ \ - hair\n disable_jawline_landmarks: false\n num_steps: 500\n optimizable_params:\
46
+ \ !!python/tuple\n - cam\n - shape\n - texture\n - lights\n rgb_sequential_tracking:\
47
+ \ !dataclass:StageRgbSequentialTrackingConfig\n align_boundary_except: !!python/tuple\n\
48
+ \ - bottomline\n - hair\n align_texture_except: !!python/tuple\n - hair\n\
49
+ \ disable_jawline_landmarks: true\n num_steps: 50\n optimizable_params:\
50
+ \ !!python/tuple\n - pose\n - joints\n - expr\n - texture\n - dynamic_offset\n\
51
+ render: !dataclass:RenderConfig\n backend: nvdiffrast\n background_eval: target\n\
52
+ \ background_train: target\n disturb_rate_bg: 0.5\n disturb_rate_fg: 0.5\n lighting_space:\
53
+ \ world\n lighting_type: SH\n use_opengl: false\nw: !dataclass:LossWeightConfig\n\
54
+ \ always_enable_jawline_landmarks: true\n blur_iter: 0\n landmark: 10.0\n photo:\
55
+ \ 30.0\n prior_eyes: 0.03\n prior_jaw: 0.3\n prior_neck: 0.3\n reg_diffuse:\
56
+ \ 100.0\n reg_expr: 0.03\n reg_light: null\n reg_offset: 300.0\n reg_offset_dynamic:\
57
+ \ 300000.0\n reg_offset_lap: 1000000.0\n reg_offset_lap_relax_coef: 0.1\n reg_offset_lap_relax_for:\
58
+ \ &id002 !!python/tuple\n - hair\n - ears\n reg_offset_relax_coef: 1.0\n reg_offset_relax_for:\
59
+ \ *id002\n reg_offset_rigid: 300.0\n reg_offset_rigid_for: !!python/tuple\n -\
60
+ \ left_ear\n - right_ear\n - neck\n - left_eye\n - right_eye\n - lips_tight\n\
61
+ \ reg_shape: 0.3\n reg_tex_pca: 0.0001\n reg_tex_res: null\n reg_tex_res_clusters:\
62
+ \ 10.0\n reg_tex_res_for: !!python/tuple\n - sclerae\n - teeth\n reg_tex_tv:\
63
+ \ 10000.0\n smooth_eyes: 0\n smooth_jaw: 0.1\n smooth_neck: 30.0\n smooth_rot:\
64
+ \ 30.0\n smooth_trans: 300.0\n"
external/human_matting/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .matting_engine import StyleMatteEngine
external/human_matting/matting_engine.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import torch
3
+ import inspect
4
+ import warnings
5
+ import torchvision
6
+ from .stylematte import StyleMatte
7
+
8
+ class StyleMatteEngine(torch.nn.Module):
9
+ def __init__(self, device='cpu', human_matting_path='./model_zoo/flame_tracking_models/matting/stylematte_synth.pt'):
10
+ super().__init__()
11
+ self._device = device
12
+ self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13
+ self._ckpt_path = human_matting_path
+ self._init_models(human_matting_path)
14
+
15
+ def _init_models(self, _ckpt_path):
16
+ # load dict
17
+ state_dict = torch.load(_ckpt_path, map_location='cpu')
18
+ # build model
19
+ model = StyleMatte()
20
+ model.load_state_dict(state_dict)
21
+ self.model = model.to(self._device).eval()
22
+
23
+ @torch.no_grad()
24
+ def forward(self, input_image, return_type='matting', background_rgb=1.0):
25
+ if not hasattr(self, 'model'):
26
+ self._init_models(self._ckpt_path)
27
+ if input_image.max() > 2.0:
28
+ warnings.warn('Image should be normalized to [0, 1].')
29
+ _, ori_h, ori_w = input_image.shape
30
+ input_image = input_image.to(self._device).float()
31
+ image = input_image.clone()
32
+ # resize
33
+ if max(ori_h, ori_w) > 1024:
34
+ scale = 1024.0 / max(ori_h, ori_w)
35
+ resized_h, resized_w = int(ori_h * scale), int(ori_w * scale)
36
+ image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True)
37
+ else:
38
+ resized_h, resized_w = ori_h, ori_w
39
+ # padding
40
+ if resized_h % 8 != 0 or resized_w % 8 != 0:
41
+ image = torchvision.transforms.functional.pad(image, ((8-resized_w % 8)%8, (8-resized_h % 8)%8, 0, 0, ), padding_mode='reflect')
42
+ # normalize and forwarding
43
+ image = self.normalize(image)[None]
44
+ predict = self.model(image)[0]
45
+ # undo padding
46
+ predict = predict[:, -resized_h:, -resized_w:]
47
+ # undo resize
48
+ if resized_h != ori_h or resized_w != ori_w:
49
+ predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True)
50
+
51
+ if return_type == 'alpha':
52
+ return predict[0]
53
+ elif return_type == 'matting':
54
+ predict = predict.expand(3, -1, -1)
55
+ matting_image = input_image.clone()
56
+ background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb
57
+ matting_image = matting_image * predict + (1-predict) * background_rgb
58
+ return matting_image, predict[0]
59
+ elif return_type == 'all':
60
+ predict = predict.expand(3, -1, -1)
61
+ background_rgb = input_image.new_ones(input_image.shape) * background_rgb
62
+ foreground_image = input_image * predict + (1-predict) * background_rgb
63
+ background_image = input_image * (1-predict) + predict * background_rgb
64
+ return foreground_image, background_image
65
+ else:
66
+ raise NotImplementedError
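A minimal usage sketch for StyleMatteEngine, not part of the upload; the image path is an illustrative assumption and the checkpoint path is simply the constructor default shown above:

    import torch
    from torchvision.io import read_image, ImageReadMode
    from external.human_matting import StyleMatteEngine

    engine = StyleMatteEngine(device='cuda' if torch.cuda.is_available() else 'cpu')
    # The engine expects a (3, H, W) float tensor in [0, 1]; read_image returns uint8.
    image = read_image('example_portrait.png', mode=ImageReadMode.RGB).float() / 255.0
    matted_rgb, alpha = engine(image, return_type='matting', background_rgb=1.0)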
external/human_matting/stylematte.py ADDED
@@ -0,0 +1,272 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from transformers import Mask2FormerForUniversalSegmentation
6
+ from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig
7
+
8
+ class StyleMatte(nn.Module):
9
+ def __init__(self):
10
+ super(StyleMatte, self).__init__()
11
+ self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
12
+ config = Mask2FormerConfig.from_json_file('./configs/stylematte_config.json')
13
+ self.pixel_decoder = Mask2FormerForUniversalSegmentation(config).base_model.pixel_level_module
14
+ self.fgf = FastGuidedFilter(eps=1e-4)
15
+ self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
16
+
17
+ def forward(self, image, normalize=False):
18
+ decoder_out = self.pixel_decoder(image)
19
+ decoder_states = list(decoder_out.decoder_hidden_states)
20
+ decoder_states.append(decoder_out.decoder_last_hidden_state)
21
+ out_pure = self.fpn(decoder_states)
22
+
23
+ image_lr = nn.functional.interpolate(image.mean(1, keepdim=True),
24
+ scale_factor=0.25,
25
+ mode='bicubic',
26
+ align_corners=True
27
+ )
28
+ out = self.conv(out_pure)
29
+ out = self.fgf(image_lr, out, image.mean(1, keepdim=True))
30
+
31
+ return torch.sigmoid(out)
32
+
33
+ def get_training_params(self):
34
+ return list(self.fpn.parameters())+list(self.conv.parameters())
35
+
36
+
37
+ def conv2d_relu(input_filters, output_filters, kernel_size=3, bias=True):
38
+ return nn.Sequential(
39
+ nn.Conv2d(input_filters, output_filters,
40
+ kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
41
+ nn.LeakyReLU(0.2, inplace=True),
42
+ nn.BatchNorm2d(output_filters)
43
+ )
44
+
45
+
46
+ def up_and_add(x, y):
47
+ return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y
48
+
49
+
50
+ class FPN_fuse(nn.Module):
51
+ def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
52
+ super(FPN_fuse, self).__init__()
53
+ assert feature_channels[0] == fpn_out
54
+ self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
55
+ for ft_size in feature_channels[1:]])
56
+ self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
57
+ * (len(feature_channels)-1))
58
+ self.conv_fusion = nn.Sequential(
59
+ nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3,
60
+ padding=1, bias=False),
61
+ nn.BatchNorm2d(fpn_out),
62
+ nn.ReLU(inplace=True),
63
+ )
64
+
65
+ def forward(self, features):
66
+
67
+ features[:-1] = [conv1x1(feature) for feature,
68
+ conv1x1 in zip(features[:-1], self.conv1x1)]
69
+ feature = up_and_add(self.smooth_conv[0](features[0]), features[1])
70
+ feature = up_and_add(self.smooth_conv[1](feature), features[2])
71
+ feature = up_and_add(self.smooth_conv[2](feature), features[3])
72
+
73
+ H, W = features[-1].size(2), features[-1].size(3)
74
+ x = [feature, features[-1]]
75
+ x = [F.interpolate(x_el, size=(H, W), mode='bilinear',
76
+ align_corners=True) for x_el in x]
77
+
78
+ x = self.conv_fusion(torch.cat(x, dim=1))
79
+
80
+ return x
81
+
82
+
83
+ class PSPModule(nn.Module):
84
+ # In the original implementation they use precise RoI pooling
85
+ # instead of adaptive average pooling.
86
+ def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
87
+ super(PSPModule, self).__init__()
88
+ out_channels = in_channels // len(bin_sizes)
89
+ self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
90
+ for b_s in bin_sizes])
91
+ self.bottleneck = nn.Sequential(
92
+ nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
93
+ kernel_size=3, padding=1, bias=False),
94
+ nn.BatchNorm2d(in_channels),
95
+ nn.ReLU(inplace=True),
96
+ nn.Dropout2d(0.1)
97
+ )
98
+
99
+ def _make_stages(self, in_channels, out_channels, bin_sz):
100
+ prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
101
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
102
+ bn = nn.BatchNorm2d(out_channels)
103
+ relu = nn.ReLU(inplace=True)
104
+ return nn.Sequential(prior, conv, bn, relu)
105
+
106
+ def forward(self, features):
107
+ h, w = features.size()[2], features.size()[3]
108
+ pyramids = [features]
109
+ pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
110
+ align_corners=True) for stage in self.stages])
111
+ output = self.bottleneck(torch.cat(pyramids, dim=1))
112
+ return output
113
+
114
+
115
+ class GuidedFilter(nn.Module):
116
+ def __init__(self, r, eps=1e-8):
117
+ super(GuidedFilter, self).__init__()
118
+
119
+ self.r = r
120
+ self.eps = eps
121
+ self.boxfilter = BoxFilter(r)
122
+
123
+ def forward(self, x, y):
124
+ n_x, c_x, h_x, w_x = x.size()
125
+ n_y, c_y, h_y, w_y = y.size()
126
+
127
+ assert n_x == n_y
128
+ assert c_x == 1 or c_x == c_y
129
+ assert h_x == h_y and w_x == w_y
130
+ assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
131
+
132
+ # N
133
+ N = self.boxfilter((x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
134
+
135
+ # mean_x
136
+ mean_x = self.boxfilter(x) / N
137
+ # mean_y
138
+ mean_y = self.boxfilter(y) / N
139
+ # cov_xy
140
+ cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
141
+ # var_x
142
+ var_x = self.boxfilter(x * x) / N - mean_x * mean_x
143
+
144
+ # A
145
+ A = cov_xy / (var_x + self.eps)
146
+ # b
147
+ b = mean_y - A * mean_x
148
+
149
+ # mean_A; mean_b
150
+ mean_A = self.boxfilter(A) / N
151
+ mean_b = self.boxfilter(b) / N
152
+
153
+ return mean_A * x + mean_b
154
+
155
+
156
+ class FastGuidedFilter(nn.Module):
157
+ def __init__(self, r=1, eps=1e-8):
158
+ super(FastGuidedFilter, self).__init__()
159
+
160
+ self.r = r
161
+ self.eps = eps
162
+ self.boxfilter = BoxFilter(r)
163
+
164
+ def forward(self, lr_x, lr_y, hr_x):
165
+ n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
166
+ n_lry, c_lry, h_lry, w_lry = lr_y.size()
167
+ n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
168
+
169
+ assert n_lrx == n_lry and n_lry == n_hrx
170
+ assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
171
+ assert h_lrx == h_lry and w_lrx == w_lry
172
+ assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1
173
+
174
+ # N
175
+ N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))
176
+
177
+ # mean_x
178
+ mean_x = self.boxfilter(lr_x) / N
179
+ # mean_y
180
+ mean_y = self.boxfilter(lr_y) / N
181
+ # cov_xy
182
+ cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
183
+ # var_x
184
+ var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
185
+
186
+ # A
187
+ A = cov_xy / (var_x + self.eps)
188
+ # b
189
+ b = mean_y - A * mean_x
190
+
191
+ # mean_A; mean_b
192
+ mean_A = F.interpolate(
193
+ A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
194
+ mean_b = F.interpolate(
195
+ b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
196
+
197
+ return mean_A*hr_x+mean_b
198
+
199
+
200
+ class DeepGuidedFilterRefiner(nn.Module):
201
+ def __init__(self, hid_channels=16):
202
+ super().__init__()
203
+ self.box_filter = nn.Conv2d(
204
+ 4, 4, kernel_size=3, padding=1, bias=False, groups=4)
205
+ self.box_filter.weight.data[...] = 1 / 9
206
+ self.conv = nn.Sequential(
207
+ nn.Conv2d(4 * 2 + hid_channels, hid_channels,
208
+ kernel_size=1, bias=False),
209
+ nn.BatchNorm2d(hid_channels),
210
+ nn.ReLU(True),
211
+ nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
212
+ nn.BatchNorm2d(hid_channels),
213
+ nn.ReLU(True),
214
+ nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
215
+ )
216
+
217
+ def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
218
+ fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
219
+ base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
220
+ base_y = torch.cat([base_fgr, base_pha], dim=1)
221
+
222
+ mean_x = self.box_filter(base_x)
223
+ mean_y = self.box_filter(base_y)
224
+ cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
225
+ var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
226
+
227
+ A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
228
+ b = mean_y - A * mean_x
229
+
230
+ H, W = fine_src.shape[2:]
231
+ A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
232
+ b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
233
+
234
+ out = A * fine_x + b
235
+ fgr, pha = out.split([3, 1], dim=1)
236
+ return fgr, pha
237
+
238
+
239
+ def diff_x(input, r):
240
+ assert input.dim() == 4
241
+
242
+ left = input[:, :, r:2 * r + 1]
243
+ middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
244
+ right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]
245
+
246
+ output = torch.cat([left, middle, right], dim=2)
247
+
248
+ return output
249
+
250
+
251
+ def diff_y(input, r):
252
+ assert input.dim() == 4
253
+
254
+ left = input[:, :, :, r:2 * r + 1]
255
+ middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
256
+ right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]
257
+
258
+ output = torch.cat([left, middle, right], dim=3)
259
+
260
+ return output
261
+
262
+
263
+ class BoxFilter(nn.Module):
264
+ def __init__(self, r):
265
+ super(BoxFilter, self).__init__()
266
+
267
+ self.r = r
268
+
269
+ def forward(self, x):
270
+ assert x.dim() == 4
271
+
272
+ return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
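A minimal sketch (not part of the repository) of how the guided-filter stack above could be exercised: BoxFilter implements a box sum via cumulative sums, and GuidedFilter uses it for edge-preserving smoothing. Shapes, radius, and eps below are illustrative assumptions.

```python
import torch

# Assumes the GuidedFilter class defined above is in scope.
gf = GuidedFilter(r=4, eps=1e-6)
guide = torch.rand(1, 1, 64, 64)   # single-channel guidance image (e.g. grayscale input)
target = torch.rand(1, 1, 64, 64)  # signal to refine (e.g. a coarse alpha matte)
refined = gf(guide, target)        # edge-preserving smoothing of `target` guided by `guide`
print(refined.shape)               # torch.Size([1, 1, 64, 64])
```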
external/landmark_detection/FaceBoxesV2/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import detector
2
+ from . import faceboxes_detector
external/landmark_detection/FaceBoxesV2/detector.py ADDED
@@ -0,0 +1,39 @@
1
+ import cv2
2
+
3
+ class Detector(object):
4
+ def __init__(self, model_arch, model_weights):
5
+ self.model_arch = model_arch
6
+ self.model_weights = model_weights
7
+
8
+ def detect(self, image, thresh):
9
+ raise NotImplementedError
10
+
11
+ def crop(self, image, detections):
12
+ crops = []
13
+ for det in detections:
14
+ xmin = max(det[2], 0)
15
+ ymin = max(det[3], 0)
16
+ width = det[4]
17
+ height = det[5]
18
+ xmax = min(xmin+width, image.shape[1])
19
+ ymax = min(ymin+height, image.shape[0])
20
+ cut = image[ymin:ymax, xmin:xmax,:]
21
+ crops.append(cut)
22
+
23
+ return crops
24
+
25
+ def draw(self, image, detections, im_scale=None):
26
+ if im_scale is not None:
27
+ image = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
28
+ detections = [[det[0],det[1],int(det[2]*im_scale),int(det[3]*im_scale),int(det[4]*im_scale),int(det[5]*im_scale)] for det in detections]
29
+
30
+ for det in detections:
31
+ xmin = det[2]
32
+ ymin = det[3]
33
+ width = det[4]
34
+ height = det[5]
35
+ xmax = xmin + width
36
+ ymax = ymin + height
37
+ cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
38
+
39
+ return image
external/landmark_detection/FaceBoxesV2/faceboxes_detector.py ADDED
@@ -0,0 +1,97 @@
1
+ from .detector import Detector
2
+ import cv2, os
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from .utils.config import cfg
7
+ from .utils.prior_box import PriorBox
8
+ from .utils.nms_wrapper import nms
9
+ from .utils.faceboxes import FaceBoxesV2
10
+ from .utils.box_utils import decode
11
+ import time
12
+
13
+ class FaceBoxesDetector(Detector):
14
+ def __init__(self, model_arch, model_weights, use_gpu, device):
15
+ super().__init__(model_arch, model_weights)
16
+ self.name = 'FaceBoxesDetector'
17
+ self.net = FaceBoxesV2(phase='test', size=None, num_classes=2) # initialize detector
18
+ self.use_gpu = use_gpu
19
+ self.device = device
20
+
21
+ state_dict = torch.load(self.model_weights, map_location=self.device)
22
+ # create new OrderedDict that does not contain `module.`
23
+ from collections import OrderedDict
24
+ new_state_dict = OrderedDict()
25
+ for k, v in state_dict.items():
26
+ name = k[7:] # remove `module.`
27
+ new_state_dict[name] = v
28
+ # load params
29
+ self.net.load_state_dict(new_state_dict)
30
+ self.net = self.net.to(self.device)
31
+ self.net.eval()
32
+
33
+
34
+ def detect(self, image, thresh=0.6, im_scale=None):
35
+ # auto resize for large images
36
+ if im_scale is None:
37
+ height, width, _ = image.shape
38
+ if min(height, width) > 600:
39
+ im_scale = 600. / min(height, width)
40
+ else:
41
+ im_scale = 1
42
+ image_scale = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
43
+
44
+ scale = torch.Tensor([image_scale.shape[1], image_scale.shape[0], image_scale.shape[1], image_scale.shape[0]])
45
+ image_scale = torch.from_numpy(image_scale.transpose(2,0,1)).to(self.device).int()
46
+ mean_tmp = torch.IntTensor([104, 117, 123]).to(self.device)
47
+ mean_tmp = mean_tmp.unsqueeze(1).unsqueeze(2)
48
+ image_scale -= mean_tmp
49
+ image_scale = image_scale.float().unsqueeze(0)
50
+ scale = scale.to(self.device)
51
+
52
+ with torch.no_grad():
53
+ out = self.net(image_scale)
54
+ #priorbox = PriorBox(cfg, out[2], (image_scale.size()[2], image_scale.size()[3]), phase='test')
55
+ priorbox = PriorBox(cfg, image_size=(image_scale.size()[2], image_scale.size()[3]))
56
+ priors = priorbox.forward()
57
+ priors = priors.to(self.device)
58
+ loc, conf = out
59
+ prior_data = priors.data
60
+ boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
61
+ boxes = boxes * scale
62
+ boxes = boxes.cpu().numpy()
63
+ scores = conf.data.cpu().numpy()[:, 1]
64
+
65
+ # ignore low scores
66
+ inds = np.where(scores > thresh)[0]
67
+ boxes = boxes[inds]
68
+ scores = scores[inds]
69
+
70
+ # keep top-K before NMS
71
+ order = scores.argsort()[::-1][:5000]
72
+ boxes = boxes[order]
73
+ scores = scores[order]
74
+
75
+ # do NMS
76
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
77
+ keep = nms(dets, 0.3)
78
+ dets = dets[keep, :]
79
+
80
+ dets = dets[:750, :]
81
+ detections_scale = []
82
+ for i in range(dets.shape[0]):
83
+ xmin = int(dets[i][0])
84
+ ymin = int(dets[i][1])
85
+ xmax = int(dets[i][2])
86
+ ymax = int(dets[i][3])
87
+ score = dets[i][4]
88
+ width = xmax - xmin
89
+ height = ymax - ymin
90
+ detections_scale.append(['face', score, xmin, ymin, width, height])
91
+
92
+ # adapt bboxes to the original image size
93
+ if len(detections_scale) > 0:
94
+ detections_scale = [[det[0],det[1],int(det[2]/im_scale),int(det[3]/im_scale),int(det[4]/im_scale),int(det[5]/im_scale)] for det in detections_scale]
95
+
96
+ return detections_scale, im_scale
97
+
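A hedged usage sketch of the detector above; the weights path and test image are placeholders, not files shipped in this folder.

```python
import cv2
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detector = FaceBoxesDetector('FaceBoxes', 'weights/faceboxesv2.pth',
                             use_gpu=torch.cuda.is_available(), device=device)

image = cv2.imread('example.jpg')  # BGR image, H x W x 3
detections, im_scale = detector.detect(image, thresh=0.6)
for label, score, xmin, ymin, width, height in detections:
    print(label, round(float(score), 3), xmin, ymin, width, height)
```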
external/landmark_detection/FaceBoxesV2/utils/__init__.py ADDED
File without changes
external/landmark_detection/FaceBoxesV2/utils/box_utils.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def point_form(boxes):
6
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7
+ representation for comparison to point form ground truth data.
8
+ Args:
9
+ boxes: (tensor) center-size default boxes from priorbox layers.
10
+ Return:
11
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12
+ """
13
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
14
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
15
+
16
+
17
+ def center_size(boxes):
18
+ """ Convert prior_boxes to (cx, cy, w, h)
19
+ representation for comparison to center-size form ground truth data.
20
+ Args:
21
+ boxes: (tensor) point_form boxes
22
+ Return:
23
+ boxes: (tensor) Converted cx, cy, w, h form of boxes.
24
+ """
25
+ return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
26
+ boxes[:, 2:] - boxes[:, :2]), 1) # w, h
27
+
28
+
29
+ def intersect(box_a, box_b):
30
+ """ We resize both tensors to [A,B,2] without new malloc:
31
+ [A,2] -> [A,1,2] -> [A,B,2]
32
+ [B,2] -> [1,B,2] -> [A,B,2]
33
+ Then we compute the area of intersect between box_a and box_b.
34
+ Args:
35
+ box_a: (tensor) bounding boxes, Shape: [A,4].
36
+ box_b: (tensor) bounding boxes, Shape: [B,4].
37
+ Return:
38
+ (tensor) intersection area, Shape: [A,B].
39
+ """
40
+ A = box_a.size(0)
41
+ B = box_b.size(0)
42
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46
+ inter = torch.clamp((max_xy - min_xy), min=0)
47
+ return inter[:, :, 0] * inter[:, :, 1]
48
+
49
+
50
+ def jaccard(box_a, box_b):
51
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52
+ is simply the intersection over union of two boxes. Here we operate on
53
+ ground truth boxes and default boxes.
54
+ E.g.:
55
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56
+ Args:
57
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59
+ Return:
60
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61
+ """
62
+ inter = intersect(box_a, box_b)
63
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
64
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
66
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67
+ union = area_a + area_b - inter
68
+ return inter / union # [A,B]
69
+
70
+
71
+ def matrix_iou(a, b):
72
+ """
73
+ return iou of a and b, numpy version for data augmentation
74
+ """
75
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77
+
78
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82
+
83
+
84
+ def matrix_iof(a, b):
85
+ """
86
+ return iof of a and b, numpy version for data augmentation
87
+ """
88
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90
+
91
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
94
+
95
+
96
+ def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
97
+ """Match each prior box with the ground truth box of the highest jaccard
98
+ overlap, encode the bounding boxes, then return the matched indices
99
+ corresponding to both confidence and location preds.
100
+ Args:
101
+ threshold: (float) The overlap threshold used when matching boxes.
102
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
103
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104
+ variances: (tensor) Variances corresponding to each prior coord,
105
+ Shape: [num_priors, 4].
106
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
107
+ loc_t: (tensor) Tensor to be filled w/ encoded location targets.
108
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
109
+ idx: (int) current batch index
110
+ Return:
111
+ The matched indices corresponding to 1)location and 2)confidence preds.
112
+ """
113
+ # jaccard index
114
+ overlaps = jaccard(
115
+ truths,
116
+ point_form(priors)
117
+ )
118
+ # (Bipartite Matching)
119
+ # [1,num_objects] best prior for each ground truth
120
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
121
+
122
+ # ignore hard gt
123
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
124
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
125
+ if best_prior_idx_filter.shape[0] <= 0:
126
+ loc_t[idx] = 0
127
+ conf_t[idx] = 0
128
+ return
129
+
130
+ # [1,num_priors] best ground truth for each prior
131
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
132
+ best_truth_idx.squeeze_(0)
133
+ best_truth_overlap.squeeze_(0)
134
+ best_prior_idx.squeeze_(1)
135
+ best_prior_idx_filter.squeeze_(1)
136
+ best_prior_overlap.squeeze_(1)
137
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
138
+ # TODO refactor: index best_prior_idx with long tensor
139
+ # ensure every gt matches with its prior of max overlap
140
+ for j in range(best_prior_idx.size(0)):
141
+ best_truth_idx[best_prior_idx[j]] = j
142
+ matches = truths[best_truth_idx] # Shape: [num_priors,4]
143
+ conf = labels[best_truth_idx] # Shape: [num_priors]
144
+ conf[best_truth_overlap < threshold] = 0 # label as background
145
+ loc = encode(matches, priors, variances)
146
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
147
+ conf_t[idx] = conf # [num_priors] top class label for each prior
148
+
149
+
150
+ def encode(matched, priors, variances):
151
+ """Encode the variances from the priorbox layers into the ground truth boxes
152
+ we have matched (based on jaccard overlap) with the prior boxes.
153
+ Args:
154
+ matched: (tensor) Coords of ground truth for each prior in point-form
155
+ Shape: [num_priors, 4].
156
+ priors: (tensor) Prior boxes in center-offset form
157
+ Shape: [num_priors,4].
158
+ variances: (list[float]) Variances of priorboxes
159
+ Return:
160
+ encoded boxes (tensor), Shape: [num_priors, 4]
161
+ """
162
+
163
+ # dist b/t match center and prior's center
164
+ g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
165
+ # encode variance
166
+ g_cxcy /= (variances[0] * priors[:, 2:])
167
+ # match wh / prior wh
168
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
169
+ g_wh = torch.log(g_wh) / variances[1]
170
+ # return target for smooth_l1_loss
171
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
172
+
173
+
174
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
175
+ def decode(loc, priors, variances):
176
+ """Decode locations from predictions using priors to undo
177
+ the encoding we did for offset regression at train time.
178
+ Args:
179
+ loc (tensor): location predictions for loc layers,
180
+ Shape: [num_priors,4]
181
+ priors (tensor): Prior boxes in center-offset form.
182
+ Shape: [num_priors,4].
183
+ variances: (list[float]) Variances of priorboxes
184
+ Return:
185
+ decoded bounding box predictions
186
+ """
187
+
188
+ boxes = torch.cat((
189
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
190
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
191
+ boxes[:, :2] -= boxes[:, 2:] / 2
192
+ boxes[:, 2:] += boxes[:, :2]
193
+ return boxes
194
+
195
+
196
+ def log_sum_exp(x):
197
+ """Utility function for computing log_sum_exp while determining
198
+ This will be used to determine unaveraged confidence loss across
199
+ all examples in a batch.
200
+ Args:
201
+ x (Variable(tensor)): conf_preds from conf layers
202
+ """
203
+ x_max = x.data.max()
204
+ return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
205
+
206
+
207
+ # Original author: Francisco Massa:
208
+ # https://github.com/fmassa/object-detection.torch
209
+ # Ported to PyTorch by Max deGroot (02/01/2017)
210
+ def nms(boxes, scores, overlap=0.5, top_k=200):
211
+ """Apply non-maximum suppression at test time to avoid detecting too many
212
+ overlapping bounding boxes for a given object.
213
+ Args:
214
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
215
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
216
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
217
+ top_k: (int) The Maximum number of box preds to consider.
218
+ Return:
219
+ The indices of the kept boxes with respect to num_priors.
220
+ """
221
+
222
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
223
+ if boxes.numel() == 0:
224
+ return keep
225
+ x1 = boxes[:, 0]
226
+ y1 = boxes[:, 1]
227
+ x2 = boxes[:, 2]
228
+ y2 = boxes[:, 3]
229
+ area = torch.mul(x2 - x1, y2 - y1)
230
+ v, idx = scores.sort(0) # sort in ascending order
231
+ # I = I[v >= 0.01]
232
+ idx = idx[-top_k:] # indices of the top-k largest vals
233
+ xx1 = boxes.new()
234
+ yy1 = boxes.new()
235
+ xx2 = boxes.new()
236
+ yy2 = boxes.new()
237
+ w = boxes.new()
238
+ h = boxes.new()
239
+
240
+ # keep = torch.Tensor()
241
+ count = 0
242
+ while idx.numel() > 0:
243
+ i = idx[-1] # index of current largest val
244
+ # keep.append(i)
245
+ keep[count] = i
246
+ count += 1
247
+ if idx.size(0) == 1:
248
+ break
249
+ idx = idx[:-1] # remove kept element from view
250
+ # load bboxes of next highest vals
251
+ torch.index_select(x1, 0, idx, out=xx1)
252
+ torch.index_select(y1, 0, idx, out=yy1)
253
+ torch.index_select(x2, 0, idx, out=xx2)
254
+ torch.index_select(y2, 0, idx, out=yy2)
255
+ # store element-wise max with next highest score
256
+ xx1 = torch.clamp(xx1, min=x1[i])
257
+ yy1 = torch.clamp(yy1, min=y1[i])
258
+ xx2 = torch.clamp(xx2, max=x2[i])
259
+ yy2 = torch.clamp(yy2, max=y2[i])
260
+ w.resize_as_(xx2)
261
+ h.resize_as_(yy2)
262
+ w = xx2 - xx1
263
+ h = yy2 - yy1
264
+ # check sizes of xx1 and xx2.. after each iteration
265
+ w = torch.clamp(w, min=0.0)
266
+ h = torch.clamp(h, min=0.0)
267
+ inter = w*h
268
+ # IoU = i / (area(a) + area(b) - i)
269
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
270
+ union = (rem_areas - inter) + area[i]
271
+ IoU = inter/union # store result in iou
272
+ # keep only elements with an IoU <= overlap
273
+ idx = idx[IoU.le(overlap)]
274
+ return keep, count
275
+
276
+
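Since `decode` is meant to undo `encode`, a quick round-trip check (illustrative values only, not from the repository) makes the offset parameterization concrete:

```python
import torch

priors = torch.tensor([[0.50, 0.50, 0.20, 0.20]])   # (cx, cy, w, h) prior
gt = torch.tensor([[0.42, 0.40, 0.58, 0.62]])        # (xmin, ymin, xmax, ymax) ground truth
variances = [0.1, 0.2]                               # same values as cfg['variance']

offsets = encode(gt, priors, variances)              # regression targets
recovered = decode(offsets, priors, variances)       # back to corner form
print(torch.allclose(recovered, gt, atol=1e-6))      # True
```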
external/landmark_detection/FaceBoxesV2/utils/build.py ADDED
@@ -0,0 +1,57 @@
1
+ # coding: utf-8
2
+
3
+ # --------------------------------------------------------
4
+ # Fast R-CNN
5
+ # Copyright (c) 2015 Microsoft
6
+ # Licensed under The MIT License [see LICENSE for details]
7
+ # Written by Ross Girshick
8
+ # --------------------------------------------------------
9
+
10
+ import os
11
+ from os.path import join as pjoin
12
+ import numpy as np
13
+ from distutils.core import setup
14
+ from distutils.extension import Extension
15
+ from Cython.Distutils import build_ext
16
+
17
+
18
+ def find_in_path(name, path):
19
+ "Find a file in a search path"
20
+ # adapted from http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
21
+ for dir in path.split(os.pathsep):
22
+ binpath = pjoin(dir, name)
23
+ if os.path.exists(binpath):
24
+ return os.path.abspath(binpath)
25
+ return None
26
+
27
+
28
+ # Obtain the numpy include directory. This logic works across numpy versions.
29
+ try:
30
+ numpy_include = np.get_include()
31
+ except AttributeError:
32
+ numpy_include = np.get_numpy_include()
33
+
34
+
35
+ # run the customize_compiler
36
+ class custom_build_ext(build_ext):
37
+ def build_extensions(self):
38
+ # customize_compiler_for_nvcc(self.compiler)
39
+ build_ext.build_extensions(self)
40
+
41
+
42
+ ext_modules = [
43
+ Extension(
44
+ "nms.cpu_nms",
45
+ ["nms/cpu_nms.pyx"],
46
+ # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
47
+ extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
48
+ include_dirs=[numpy_include]
49
+ )
50
+ ]
51
+
52
+ setup(
53
+ name='mot_utils',
54
+ ext_modules=ext_modules,
55
+ # inject our custom trigger
56
+ cmdclass={'build_ext': custom_build_ext},
57
+ )
external/landmark_detection/FaceBoxesV2/utils/config.py ADDED
@@ -0,0 +1,14 @@
1
+ # config.py
2
+
3
+ cfg = {
4
+ 'name': 'FaceBoxes',
5
+ #'min_dim': 1024,
6
+ #'feature_maps': [[32, 32], [16, 16], [8, 8]],
7
+ # 'aspect_ratios': [[1], [1], [1]],
8
+ 'min_sizes': [[32, 64, 128], [256], [512]],
9
+ 'steps': [32, 64, 128],
10
+ 'variance': [0.1, 0.2],
11
+ 'clip': False,
12
+ 'loc_weight': 2.0,
13
+ 'gpu_train': True
14
+ }
external/landmark_detection/FaceBoxesV2/utils/faceboxes.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class BasicConv2d(nn.Module):
7
+
8
+ def __init__(self, in_channels, out_channels, **kwargs):
9
+ super(BasicConv2d, self).__init__()
10
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
11
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
12
+
13
+ def forward(self, x):
14
+ x = self.conv(x)
15
+ x = self.bn(x)
16
+ return F.relu(x, inplace=True)
17
+
18
+
19
+ class Inception(nn.Module):
20
+
21
+ def __init__(self):
22
+ super(Inception, self).__init__()
23
+ self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0)
24
+ self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0)
25
+ self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0)
26
+ self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1)
27
+ self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0)
28
+ self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1)
29
+ self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
30
+
31
+ def forward(self, x):
32
+ branch1x1 = self.branch1x1(x)
33
+
34
+ branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
35
+ branch1x1_2 = self.branch1x1_2(branch1x1_pool)
36
+
37
+ branch3x3_reduce = self.branch3x3_reduce(x)
38
+ branch3x3 = self.branch3x3(branch3x3_reduce)
39
+
40
+ branch3x3_reduce_2 = self.branch3x3_reduce_2(x)
41
+ branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2)
42
+ branch3x3_3 = self.branch3x3_3(branch3x3_2)
43
+
44
+ outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3]
45
+ return torch.cat(outputs, 1)
46
+
47
+
48
+ class CRelu(nn.Module):
49
+
50
+ def __init__(self, in_channels, out_channels, **kwargs):
51
+ super(CRelu, self).__init__()
52
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
53
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
54
+
55
+ def forward(self, x):
56
+ x = self.conv(x)
57
+ x = self.bn(x)
58
+ x = torch.cat([x, -x], 1)
59
+ x = F.relu(x, inplace=True)
60
+ return x
61
+
62
+
63
+ class FaceBoxes(nn.Module):
64
+
65
+ def __init__(self, phase, size, num_classes):
66
+ super(FaceBoxes, self).__init__()
67
+ self.phase = phase
68
+ self.num_classes = num_classes
69
+ self.size = size
70
+
71
+ self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3)
72
+ self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2)
73
+
74
+ self.inception1 = Inception()
75
+ self.inception2 = Inception()
76
+ self.inception3 = Inception()
77
+
78
+ self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
79
+ self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
80
+
81
+ self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
82
+ self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
83
+
84
+ self.loc, self.conf = self.multibox(self.num_classes)
85
+
86
+ if self.phase == 'test':
87
+ self.softmax = nn.Softmax(dim=-1)
88
+
89
+ if self.phase == 'train':
90
+ for m in self.modules():
91
+ if isinstance(m, nn.Conv2d):
92
+ if m.bias is not None:
93
+ nn.init.xavier_normal_(m.weight.data)
94
+ m.bias.data.fill_(0.02)
95
+ else:
96
+ m.weight.data.normal_(0, 0.01)
97
+ elif isinstance(m, nn.BatchNorm2d):
98
+ m.weight.data.fill_(1)
99
+ m.bias.data.zero_()
100
+
101
+ def multibox(self, num_classes):
102
+ loc_layers = []
103
+ conf_layers = []
104
+ loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
105
+ conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
106
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
107
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
108
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
109
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
110
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
111
+
112
+ def forward(self, x):
113
+
114
+ detection_sources = list()
115
+ loc = list()
116
+ conf = list()
117
+
118
+ x = self.conv1(x)
119
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
120
+ x = self.conv2(x)
121
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
122
+ x = self.inception1(x)
123
+ x = self.inception2(x)
124
+ x = self.inception3(x)
125
+ detection_sources.append(x)
126
+
127
+ x = self.conv3_1(x)
128
+ x = self.conv3_2(x)
129
+ detection_sources.append(x)
130
+
131
+ x = self.conv4_1(x)
132
+ x = self.conv4_2(x)
133
+ detection_sources.append(x)
134
+
135
+ for (x, l, c) in zip(detection_sources, self.loc, self.conf):
136
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
137
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
138
+
139
+ loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
140
+ conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
141
+
142
+ if self.phase == "test":
143
+ output = (loc.view(loc.size(0), -1, 4),
144
+ self.softmax(conf.view(conf.size(0), -1, self.num_classes)))
145
+ else:
146
+ output = (loc.view(loc.size(0), -1, 4),
147
+ conf.view(conf.size(0), -1, self.num_classes))
148
+
149
+ return output
150
+
151
+ class FaceBoxesV2(nn.Module):
152
+
153
+ def __init__(self, phase, size, num_classes):
154
+ super(FaceBoxesV2, self).__init__()
155
+ self.phase = phase
156
+ self.num_classes = num_classes
157
+ self.size = size
158
+
159
+ self.conv1 = BasicConv2d(3, 8, kernel_size=3, stride=2, padding=1)
160
+ self.conv2 = BasicConv2d(8, 16, kernel_size=3, stride=2, padding=1)
161
+ self.conv3 = BasicConv2d(16, 32, kernel_size=3, stride=2, padding=1)
162
+ self.conv4 = BasicConv2d(32, 64, kernel_size=3, stride=2, padding=1)
163
+ self.conv5 = BasicConv2d(64, 128, kernel_size=3, stride=2, padding=1)
164
+
165
+ self.inception1 = Inception()
166
+ self.inception2 = Inception()
167
+ self.inception3 = Inception()
168
+
169
+ self.conv6_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
170
+ self.conv6_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
171
+
172
+ self.conv7_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
173
+ self.conv7_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
174
+
175
+ self.loc, self.conf = self.multibox(self.num_classes)
176
+
177
+ if self.phase == 'test':
178
+ self.softmax = nn.Softmax(dim=-1)
179
+
180
+ if self.phase == 'train':
181
+ for m in self.modules():
182
+ if isinstance(m, nn.Conv2d):
183
+ if m.bias is not None:
184
+ nn.init.xavier_normal_(m.weight.data)
185
+ m.bias.data.fill_(0.02)
186
+ else:
187
+ m.weight.data.normal_(0, 0.01)
188
+ elif isinstance(m, nn.BatchNorm2d):
189
+ m.weight.data.fill_(1)
190
+ m.bias.data.zero_()
191
+
192
+ def multibox(self, num_classes):
193
+ loc_layers = []
194
+ conf_layers = []
195
+ loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
196
+ conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
197
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
198
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
199
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
200
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
201
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
202
+
203
+ def forward(self, x):
204
+
205
+ sources = list()
206
+ loc = list()
207
+ conf = list()
208
+
209
+ x = self.conv1(x)
210
+ x = self.conv2(x)
211
+ x = self.conv3(x)
212
+ x = self.conv4(x)
213
+ x = self.conv5(x)
214
+ x = self.inception1(x)
215
+ x = self.inception2(x)
216
+ x = self.inception3(x)
217
+ sources.append(x)
218
+ x = self.conv6_1(x)
219
+ x = self.conv6_2(x)
220
+ sources.append(x)
221
+ x = self.conv7_1(x)
222
+ x = self.conv7_2(x)
223
+ sources.append(x)
224
+
225
+ for (x, l, c) in zip(sources, self.loc, self.conf):
226
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
227
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
228
+
229
+ loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
230
+ conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
231
+
232
+ if self.phase == "test":
233
+ output = (loc.view(loc.size(0), -1, 4),
234
+ self.softmax(conf.view(-1, self.num_classes)))
235
+ else:
236
+ output = (loc.view(loc.size(0), -1, 4),
237
+ conf.view(conf.size(0), -1, self.num_classes))
238
+
239
+ return output
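As a shape sanity check, the sketch below (an assumption-laden example, not repository code) runs FaceBoxesV2 on a dummy 256x256 input; with the three detection sources and 21/1/1 anchors per cell this yields 1364 anchors.

```python
import torch

net = FaceBoxesV2(phase='test', size=None, num_classes=2).eval()
with torch.no_grad():
    loc, conf = net(torch.randn(1, 3, 256, 256))
# 8*8*21 + 4*4*1 + 2*2*1 = 1364 anchors for a 256x256 input
print(loc.shape)   # torch.Size([1, 1364, 4])
print(conf.shape)  # torch.Size([1364, 2]); softmax over the class dimension
```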
external/landmark_detection/FaceBoxesV2/utils/make.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/usr/bin/env bash
2
+ python3 build.py build_ext --inplace
3
+
external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py ADDED
File without changes
external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c ADDED
The diff for this file is too large to render. See raw diff
 
external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py ADDED
File without changes
external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx ADDED
@@ -0,0 +1,163 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ cimport numpy as np
10
+
11
+ cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
12
+ return a if a >= b else b
13
+
14
+ cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
15
+ return a if a <= b else b
16
+
17
+ def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
18
+ cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
19
+ cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
20
+ cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
21
+ cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
22
+ cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
23
+
24
+ cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
25
+ cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
26
+
27
+ cdef int ndets = dets.shape[0]
28
+ cdef np.ndarray[np.int_t, ndim=1] suppressed = \
29
+ np.zeros((ndets), dtype=np.int)
30
+
31
+ # nominal indices
32
+ cdef int _i, _j
33
+ # sorted indices
34
+ cdef int i, j
35
+ # temp variables for box i's (the box currently under consideration)
36
+ cdef np.float32_t ix1, iy1, ix2, iy2, iarea
37
+ # variables for computing overlap with box j (lower scoring box)
38
+ cdef np.float32_t xx1, yy1, xx2, yy2
39
+ cdef np.float32_t w, h
40
+ cdef np.float32_t inter, ovr
41
+
42
+ keep = []
43
+ for _i in range(ndets):
44
+ i = order[_i]
45
+ if suppressed[i] == 1:
46
+ continue
47
+ keep.append(i)
48
+ ix1 = x1[i]
49
+ iy1 = y1[i]
50
+ ix2 = x2[i]
51
+ iy2 = y2[i]
52
+ iarea = areas[i]
53
+ for _j in range(_i + 1, ndets):
54
+ j = order[_j]
55
+ if suppressed[j] == 1:
56
+ continue
57
+ xx1 = max(ix1, x1[j])
58
+ yy1 = max(iy1, y1[j])
59
+ xx2 = min(ix2, x2[j])
60
+ yy2 = min(iy2, y2[j])
61
+ w = max(0.0, xx2 - xx1 + 1)
62
+ h = max(0.0, yy2 - yy1 + 1)
63
+ inter = w * h
64
+ ovr = inter / (iarea + areas[j] - inter)
65
+ if ovr >= thresh:
66
+ suppressed[j] = 1
67
+
68
+ return keep
69
+
70
+ def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0):
71
+ cdef unsigned int N = boxes.shape[0]
72
+ cdef float iw, ih, box_area
73
+ cdef float ua
74
+ cdef int pos = 0
75
+ cdef float maxscore = 0
76
+ cdef int maxpos = 0
77
+ cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
78
+
79
+ for i in range(N):
80
+ maxscore = boxes[i, 4]
81
+ maxpos = i
82
+
83
+ tx1 = boxes[i,0]
84
+ ty1 = boxes[i,1]
85
+ tx2 = boxes[i,2]
86
+ ty2 = boxes[i,3]
87
+ ts = boxes[i,4]
88
+
89
+ pos = i + 1
90
+ # get max box
91
+ while pos < N:
92
+ if maxscore < boxes[pos, 4]:
93
+ maxscore = boxes[pos, 4]
94
+ maxpos = pos
95
+ pos = pos + 1
96
+
97
+ # add max box as a detection
98
+ boxes[i,0] = boxes[maxpos,0]
99
+ boxes[i,1] = boxes[maxpos,1]
100
+ boxes[i,2] = boxes[maxpos,2]
101
+ boxes[i,3] = boxes[maxpos,3]
102
+ boxes[i,4] = boxes[maxpos,4]
103
+
104
+ # swap ith box with position of max box
105
+ boxes[maxpos,0] = tx1
106
+ boxes[maxpos,1] = ty1
107
+ boxes[maxpos,2] = tx2
108
+ boxes[maxpos,3] = ty2
109
+ boxes[maxpos,4] = ts
110
+
111
+ tx1 = boxes[i,0]
112
+ ty1 = boxes[i,1]
113
+ tx2 = boxes[i,2]
114
+ ty2 = boxes[i,3]
115
+ ts = boxes[i,4]
116
+
117
+ pos = i + 1
118
+ # NMS iterations, note that N changes if detection boxes fall below threshold
119
+ while pos < N:
120
+ x1 = boxes[pos, 0]
121
+ y1 = boxes[pos, 1]
122
+ x2 = boxes[pos, 2]
123
+ y2 = boxes[pos, 3]
124
+ s = boxes[pos, 4]
125
+
126
+ area = (x2 - x1 + 1) * (y2 - y1 + 1)
127
+ iw = (min(tx2, x2) - max(tx1, x1) + 1)
128
+ if iw > 0:
129
+ ih = (min(ty2, y2) - max(ty1, y1) + 1)
130
+ if ih > 0:
131
+ ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
132
+ ov = iw * ih / ua #iou between max box and detection box
133
+
134
+ if method == 1: # linear
135
+ if ov > Nt:
136
+ weight = 1 - ov
137
+ else:
138
+ weight = 1
139
+ elif method == 2: # gaussian
140
+ weight = np.exp(-(ov * ov)/sigma)
141
+ else: # original NMS
142
+ if ov > Nt:
143
+ weight = 0
144
+ else:
145
+ weight = 1
146
+
147
+ boxes[pos, 4] = weight*boxes[pos, 4]
148
+
149
+ # if box score falls below threshold, discard the box by swapping with last box
150
+ # update N
151
+ if boxes[pos, 4] < threshold:
152
+ boxes[pos,0] = boxes[N-1, 0]
153
+ boxes[pos,1] = boxes[N-1, 1]
154
+ boxes[pos,2] = boxes[N-1, 2]
155
+ boxes[pos,3] = boxes[N-1, 3]
156
+ boxes[pos,4] = boxes[N-1, 4]
157
+ N = N - 1
158
+ pos = pos - 1
159
+
160
+ pos = pos + 1
161
+
162
+ keep = [i for i in range(N)]
163
+ return keep
external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp ADDED
@@ -0,0 +1,2 @@
1
+ void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
2
+ int boxes_dim, float nms_overlap_thresh, int device_id);
external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx ADDED
@@ -0,0 +1,31 @@
1
+ # --------------------------------------------------------
2
+ # Faster R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ cimport numpy as np
10
+
11
+ assert sizeof(int) == sizeof(np.int32_t)
12
+
13
+ cdef extern from "gpu_nms.hpp":
14
+ void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
15
+
16
+ def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
17
+ np.int32_t device_id=0):
18
+ cdef int boxes_num = dets.shape[0]
19
+ cdef int boxes_dim = dets.shape[1]
20
+ cdef int num_out
21
+ cdef np.ndarray[np.int32_t, ndim=1] \
22
+ keep = np.zeros(boxes_num, dtype=np.int32)
23
+ cdef np.ndarray[np.float32_t, ndim=1] \
24
+ scores = dets[:, 4]
25
+ cdef np.ndarray[np.int_t, ndim=1] \
26
+ order = scores.argsort()[::-1]
27
+ cdef np.ndarray[np.float32_t, ndim=2] \
28
+ sorted_dets = dets[order, :]
29
+ _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
30
+ keep = keep[:num_out]
31
+ return list(order[keep])
external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu ADDED
@@ -0,0 +1,144 @@
1
+ // ------------------------------------------------------------------
2
+ // Faster R-CNN
3
+ // Copyright (c) 2015 Microsoft
4
+ // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
5
+ // Written by Shaoqing Ren
6
+ // ------------------------------------------------------------------
7
+
8
+ #include "gpu_nms.hpp"
9
+ #include <vector>
10
+ #include <iostream>
11
+
12
+ #define CUDA_CHECK(condition) \
13
+ /* Code block avoids redefinition of cudaError_t error */ \
14
+ do { \
15
+ cudaError_t error = condition; \
16
+ if (error != cudaSuccess) { \
17
+ std::cout << cudaGetErrorString(error) << std::endl; \
18
+ } \
19
+ } while (0)
20
+
21
+ #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
22
+ int const threadsPerBlock = sizeof(unsigned long long) * 8;
23
+
24
+ __device__ inline float devIoU(float const * const a, float const * const b) {
25
+ float left = max(a[0], b[0]), right = min(a[2], b[2]);
26
+ float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
27
+ float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
28
+ float interS = width * height;
29
+ float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
30
+ float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
31
+ return interS / (Sa + Sb - interS);
32
+ }
33
+
34
+ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
35
+ const float *dev_boxes, unsigned long long *dev_mask) {
36
+ const int row_start = blockIdx.y;
37
+ const int col_start = blockIdx.x;
38
+
39
+ // if (row_start > col_start) return;
40
+
41
+ const int row_size =
42
+ min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
43
+ const int col_size =
44
+ min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
45
+
46
+ __shared__ float block_boxes[threadsPerBlock * 5];
47
+ if (threadIdx.x < col_size) {
48
+ block_boxes[threadIdx.x * 5 + 0] =
49
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
50
+ block_boxes[threadIdx.x * 5 + 1] =
51
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
52
+ block_boxes[threadIdx.x * 5 + 2] =
53
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
54
+ block_boxes[threadIdx.x * 5 + 3] =
55
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
56
+ block_boxes[threadIdx.x * 5 + 4] =
57
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
58
+ }
59
+ __syncthreads();
60
+
61
+ if (threadIdx.x < row_size) {
62
+ const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
63
+ const float *cur_box = dev_boxes + cur_box_idx * 5;
64
+ int i = 0;
65
+ unsigned long long t = 0;
66
+ int start = 0;
67
+ if (row_start == col_start) {
68
+ start = threadIdx.x + 1;
69
+ }
70
+ for (i = start; i < col_size; i++) {
71
+ if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
72
+ t |= 1ULL << i;
73
+ }
74
+ }
75
+ const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
76
+ dev_mask[cur_box_idx * col_blocks + col_start] = t;
77
+ }
78
+ }
79
+
80
+ void _set_device(int device_id) {
81
+ int current_device;
82
+ CUDA_CHECK(cudaGetDevice(&current_device));
83
+ if (current_device == device_id) {
84
+ return;
85
+ }
86
+ // The call to cudaSetDevice must come before any calls to Get, which
87
+ // may perform initialization using the GPU.
88
+ CUDA_CHECK(cudaSetDevice(device_id));
89
+ }
90
+
91
+ void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
92
+ int boxes_dim, float nms_overlap_thresh, int device_id) {
93
+ _set_device(device_id);
94
+
95
+ float* boxes_dev = NULL;
96
+ unsigned long long* mask_dev = NULL;
97
+
98
+ const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
99
+
100
+ CUDA_CHECK(cudaMalloc(&boxes_dev,
101
+ boxes_num * boxes_dim * sizeof(float)));
102
+ CUDA_CHECK(cudaMemcpy(boxes_dev,
103
+ boxes_host,
104
+ boxes_num * boxes_dim * sizeof(float),
105
+ cudaMemcpyHostToDevice));
106
+
107
+ CUDA_CHECK(cudaMalloc(&mask_dev,
108
+ boxes_num * col_blocks * sizeof(unsigned long long)));
109
+
110
+ dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
111
+ DIVUP(boxes_num, threadsPerBlock));
112
+ dim3 threads(threadsPerBlock);
113
+ nms_kernel<<<blocks, threads>>>(boxes_num,
114
+ nms_overlap_thresh,
115
+ boxes_dev,
116
+ mask_dev);
117
+
118
+ std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
119
+ CUDA_CHECK(cudaMemcpy(&mask_host[0],
120
+ mask_dev,
121
+ sizeof(unsigned long long) * boxes_num * col_blocks,
122
+ cudaMemcpyDeviceToHost));
123
+
124
+ std::vector<unsigned long long> remv(col_blocks);
125
+ memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
126
+
127
+ int num_to_keep = 0;
128
+ for (int i = 0; i < boxes_num; i++) {
129
+ int nblock = i / threadsPerBlock;
130
+ int inblock = i % threadsPerBlock;
131
+
132
+ if (!(remv[nblock] & (1ULL << inblock))) {
133
+ keep_out[num_to_keep++] = i;
134
+ unsigned long long *p = &mask_host[0] + i * col_blocks;
135
+ for (int j = nblock; j < col_blocks; j++) {
136
+ remv[j] |= p[j];
137
+ }
138
+ }
139
+ }
140
+ *num_out = num_to_keep;
141
+
142
+ CUDA_CHECK(cudaFree(boxes_dev));
143
+ CUDA_CHECK(cudaFree(mask_dev));
144
+ }
external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py ADDED
@@ -0,0 +1,38 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+
10
+ def py_cpu_nms(dets, thresh):
11
+ """Pure Python NMS baseline."""
12
+ x1 = dets[:, 0]
13
+ y1 = dets[:, 1]
14
+ x2 = dets[:, 2]
15
+ y2 = dets[:, 3]
16
+ scores = dets[:, 4]
17
+
18
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
19
+ order = scores.argsort()[::-1]
20
+
21
+ keep = []
22
+ while order.size > 0:
23
+ i = order[0]
24
+ keep.append(i)
25
+ xx1 = np.maximum(x1[i], x1[order[1:]])
26
+ yy1 = np.maximum(y1[i], y1[order[1:]])
27
+ xx2 = np.minimum(x2[i], x2[order[1:]])
28
+ yy2 = np.minimum(y2[i], y2[order[1:]])
29
+
30
+ w = np.maximum(0.0, xx2 - xx1 + 1)
31
+ h = np.maximum(0.0, yy2 - yy1 + 1)
32
+ inter = w * h
33
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
34
+
35
+ inds = np.where(ovr <= thresh)[0]
36
+ order = order[inds + 1]
37
+
38
+ return keep
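A tiny illustrative run of `py_cpu_nms` (example values only): two heavily overlapping boxes and one distant box, suppressed at an IoU threshold of 0.5.

```python
import numpy as np

dets = np.array([
    [10, 10, 50, 50, 0.9],       # kept: highest score
    [12, 12, 52, 52, 0.8],       # suppressed: IoU with the first box is ~0.83
    [100, 100, 140, 140, 0.7],   # kept: no overlap with the first box
], dtype=np.float32)

print(py_cpu_nms(dets, thresh=0.5))   # [0, 2]
```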
external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py ADDED
@@ -0,0 +1,15 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ from .nms.cpu_nms import cpu_nms, cpu_soft_nms
9
+
10
+ def nms(dets, thresh):
11
+ """Dispatch to either CPU or GPU NMS implementations."""
12
+
13
+ if dets.shape[0] == 0:
14
+ return []
15
+ return cpu_nms(dets, thresh)
external/landmark_detection/FaceBoxesV2/utils/prior_box.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ from itertools import product as product
3
+ import numpy as np
4
+ from math import ceil
5
+
6
+
7
+ class PriorBox(object):
8
+ def __init__(self, cfg, image_size=None, phase='train'):
9
+ super(PriorBox, self).__init__()
10
+ #self.aspect_ratios = cfg['aspect_ratios']
11
+ self.min_sizes = cfg['min_sizes']
12
+ self.steps = cfg['steps']
13
+ self.clip = cfg['clip']
14
+ self.image_size = image_size
15
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
16
+
17
+ def forward(self):
18
+ anchors = []
19
+ for k, f in enumerate(self.feature_maps):
20
+ min_sizes = self.min_sizes[k]
21
+ for i, j in product(range(f[0]), range(f[1])):
22
+ for min_size in min_sizes:
23
+ s_kx = min_size / self.image_size[1]
24
+ s_ky = min_size / self.image_size[0]
25
+ if min_size == 32:
26
+ dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]]
27
+ dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]]
28
+ for cy, cx in product(dense_cy, dense_cx):
29
+ anchors += [cx, cy, s_kx, s_ky]
30
+ elif min_size == 64:
31
+ dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.5]]
32
+ dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.5]]
33
+ for cy, cx in product(dense_cy, dense_cx):
34
+ anchors += [cx, cy, s_kx, s_ky]
35
+ else:
36
+ cx = (j + 0.5) * self.steps[k] / self.image_size[1]
37
+ cy = (i + 0.5) * self.steps[k] / self.image_size[0]
38
+ anchors += [cx, cy, s_kx, s_ky]
39
+ # back to torch land
40
+ output = torch.Tensor(anchors).view(-1, 4)
41
+ if self.clip:
42
+ output.clamp_(max=1, min=0)
43
+ return output
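Combined with the `cfg` dict from config.py, the anchor grid for a given input size can be generated as in this sketch (illustrative only):

```python
import torch

# Assumes `cfg` from config.py and the PriorBox class above are in scope.
priors = PriorBox(cfg, image_size=(256, 256)).forward()
# 8*8 cells * 21 anchors + 4*4 * 1 + 2*2 * 1 = 1364 anchors for this config
print(priors.shape)  # torch.Size([1364, 4]); rows are (cx, cy, w, h) relative to the input size
```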
external/landmark_detection/FaceBoxesV2/utils/timer.py ADDED
@@ -0,0 +1,40 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import time
9
+
10
+
11
+ class Timer(object):
12
+ """A simple timer."""
13
+ def __init__(self):
14
+ self.total_time = 0.
15
+ self.calls = 0
16
+ self.start_time = 0.
17
+ self.diff = 0.
18
+ self.average_time = 0.
19
+
20
+ def tic(self):
21
+ # using time.time instead of time.clock because time.clock
22
+ # does not normalize for multithreading
23
+ self.start_time = time.time()
24
+
25
+ def toc(self, average=True):
26
+ self.diff = time.time() - self.start_time
27
+ self.total_time += self.diff
28
+ self.calls += 1
29
+ self.average_time = self.total_time / self.calls
30
+ if average:
31
+ return self.average_time
32
+ else:
33
+ return self.diff
34
+
35
+ def clear(self):
36
+ self.total_time = 0.
37
+ self.calls = 0
38
+ self.start_time = 0.
39
+ self.diff = 0.
40
+ self.average_time = 0.
external/landmark_detection/README.md ADDED
@@ -0,0 +1,110 @@
1
+ # STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection.
2
+
3
+ Paper Link: [arxiv](https://arxiv.org/abs/2306.02763) | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_STAR_Loss_Reducing_Semantic_Ambiguity_in_Facial_Landmark_Detection_CVPR_2023_paper.pdf)
4
+
5
+
6
+ - Pytorch implementation of **S**elf-adap**T**ive **A**mbiguity **R**eduction (**STAR**) loss.
7
+ - STAR loss is a self-adaptive anisotropic direction loss, which can be used in heatmap regression-based methods for facial landmark detection.
8
+ - Specifically, we find that semantic ambiguity results in an anisotropic predicted distribution, which inspires us to use the predicted distribution to represent semantic ambiguity. We apply PCA to characterize the predicted distribution and thereby estimate the direction and intensity of semantic ambiguity. Based on this, STAR loss adaptively suppresses the prediction error along the ambiguity direction, mitigating the impact of ambiguous annotations during training (a toy sketch of this PCA step follows the figure below). More details can be found in our paper.
9
+ <p align="center">
10
+ <img src="./images/framework.png" width="80%">
11
+ </p>
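The following toy sketch (not the repository's STAR loss implementation) illustrates the PCA step described above: treat one predicted heatmap as a 2-D distribution and read the ambiguity direction and intensity off the eigendecomposition of its covariance.

```python
import torch

h, w = 64, 64
heatmap = torch.rand(h, w)                                # placeholder network output
prob = heatmap.flatten().softmax(0)                       # normalize to a distribution

ys = torch.arange(h, dtype=torch.float32).repeat_interleave(w)
xs = torch.arange(w, dtype=torch.float32).repeat(h)
coords = torch.stack([xs, ys], dim=1)                     # (h*w, 2) pixel coordinates

mean = (prob[:, None] * coords).sum(0)                    # soft-argmax landmark location
centered = coords - mean
cov = (prob[:, None, None] * centered[:, :, None] * centered[:, None, :]).sum(0)

eigvals, eigvecs = torch.linalg.eigh(cov)                 # PCA of the predicted distribution
ambiguity_direction = eigvecs[:, -1]                      # first principal component
ambiguity_intensity = eigvals[-1] / eigvals[0]            # anisotropy of the distribution
```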
12
+
13
+
14
+
15
+
16
+ ## Dependencies
17
+
18
+ * python==3.7.3
19
+ * PyTorch=1.6.0
20
+ * requirements.txt
21
+
22
+ ## Dataset Preparation
23
+
24
+ - Step1: Download the raw images from [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/#dataset), [300W](https://ibug.doc.ic.ac.uk/resources/300-W/), and [WFLW](https://wywu.github.io/projects/LAB/WFLW.html).
25
+ - Step2: We follow the data preprocessing in [ADNet](https://openaccess.thecvf.com/content/ICCV2021/papers/Huang_ADNet_Leveraging_Error-Bias_Towards_Normal_Direction_in_Face_Alignment_ICCV_2021_paper.pdf), and the metadata can be downloaded from [the corresponding repository](https://github.com/huangyangyu/ADNet).
26
+ - Step3: Make them look like this:
27
+ ```script
28
+ # the dataset directory:
29
+ |-- ${image_dir}
30
+ |-- WFLW
31
+ | -- WFLW_images
32
+ |-- 300W
33
+ | -- afw
34
+ | -- helen
35
+ | -- ibug
36
+ | -- lfpw
37
+ |-- COFW
38
+ | -- train
39
+ | -- test
40
+ |-- ${annot_dir}
41
+ |-- WFLW
42
+ |-- train.tsv, test.tsv
43
+ |-- 300W
44
+ |-- train.tsv, test.tsv
45
+ |--COFW
46
+ |-- train.tsv, test.tsv
47
+ ```
48
+
49
+ ## Usage
50
+ * Work directory: set the ${ckpt_dir} in ./conf/alignment.py.
51
+ * Pretrained model:
52
+
53
+ | Dataset | Model |
54
+ |:-----------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|
55
+ | WFLW | [google](https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view?usp=sharing) / [baidu](https://pan.baidu.com/s/10vvI-ovs3x9NrdmpnXK6sg?pwd=u0yu) |
56
+ | 300W | [google](https://drive.google.com/file/d/1Fiu3hjjkQRdKsWE9IgyNPdiJSz9_MzA5/view?usp=sharing) / [baidu](https://pan.baidu.com/s/1bjUhLq1zS1XSl1nX78fU7A?pwd=yb2s) |
57
+ | COFW | [google](https://drive.google.com/file/d/1NFcZ9jzql_jnn3ulaSzUlyhS05HWB9n_/view?usp=drive_link) / [baidu](https://pan.baidu.com/s/1XO6hDZ8siJLTgFcpyu1Tzw?pwd=m57n) |
58
+
59
+
60
+ ### Training
61
+ ```shell
62
+ python main.py --mode=train --device_ids=0,1,2,3 \
63
+ --image_dir=${image_dir} --annot_dir=${annot_dir} \
64
+ --data_definition={WFLW, 300W, COFW}
65
+ ```
66
+
67
+ ### Testing
68
+ ```shell
69
+ python main.py --mode=test --device_ids=0 \
70
+ --image_dir=${image_dir} --annot_dir=${annot_dir} \
71
+ --data_definition={WFLW, 300W, COFW} \
72
+ --pretrained_weight=${model_path} \
73
+ ```
74
+
75
+ ### Evaluation
76
+ ```shell
77
+ python evaluate.py --device_ids=0 \
78
+ --model_path=${model_path} --metadata_path=${metadata_path} \
79
+ --image_dir=${image_dir} --data_definition={WFLW, 300W, COFW} \
80
+ ```
81
+
82
+ To test on your own images, run the demo script:
83
+ ```shell
84
+ python demo.py
85
+ ```
86
+
87
+
88
+ ## Results
89
+ The models trained with STAR Loss achieve **SOTA** performance on all of the COFW, 300W, and WFLW datasets.
90
+
91
+ <p align="center">
92
+ <img src="./images/results.png" width="80%">
93
+ </p>
94
+
95
+ ## BibTeX Citation
96
+ Please consider citing our paper in your publications if the project helps your research. The BibTeX reference is as follows.
97
+ ```
98
+ @inproceedings{Zhou_2023_CVPR,
99
+ author = {Zhou, Zhenglin and Li, Huaxia and Liu, Hong and Wang, Nanyang and Yu, Gang and Ji, Rongrong},
100
+ title = {STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection},
101
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
102
+ month = {June},
103
+ year = {2023},
104
+ pages = {15475-15484}
105
+ }
106
+ ```
107
+
108
+ ## Acknowledgments
109
+ This repository is built on top of [ADNet](https://github.com/huangyangyu/ADNet).
110
+ Thanks for this strong baseline.
external/landmark_detection/conf/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .alignment import Alignment
external/landmark_detection/conf/alignment.py ADDED
@@ -0,0 +1,239 @@
1
+ import os.path as osp
2
+ from .base import Base
3
+
4
+
5
+ class Alignment(Base):
6
+ """
7
+ Alignment configure file, which contains training parameters of alignment.
8
+ """
9
+
10
+ def __init__(self, args):
11
+ super(Alignment, self).__init__('alignment')
12
+ self.ckpt_dir = '/mnt/workspace/humanAIGC/project/STAR/weights'
13
+ self.net = "stackedHGnet_v1"
14
+ self.nstack = 4
15
+ self.loader_type = "alignment"
16
+ self.data_definition = "300W" # COFW, 300W, WFLW
17
+ self.test_file = "test.tsv"
18
+
19
+ # image
20
+ self.channels = 3
21
+ self.width = 256
22
+ self.height = 256
23
+ self.means = (127.5, 127.5, 127.5)
24
+ self.scale = 1 / 127.5
25
+ self.aug_prob = 1.0
26
+
27
+ self.display_iteration = 10
28
+ self.val_epoch = 1
29
+ self.valset = "test.tsv"
30
+ self.norm_type = 'default'
31
+ self.encoder_type = 'default'
32
+ self.decoder_type = 'default'
33
+
34
+ # scheduler & optimizer
35
+ self.milestones = [200, 350, 450]
36
+ self.max_epoch = 260
37
+ self.optimizer = "adam"
38
+ self.learn_rate = 0.001
39
+ self.weight_decay = 0.00001
40
+ self.betas = [0.9, 0.999]
41
+ self.gamma = 0.1
42
+
43
+ # batch_size & workers
44
+ self.batch_size = 32
45
+ self.train_num_workers = 16
46
+ self.val_batch_size = 32
47
+ self.val_num_workers = 16
48
+ self.test_batch_size = 16
49
+ self.test_num_workers = 0
50
+
51
+ # tricks
52
+ self.ema = True
53
+ self.add_coord = True
54
+ self.use_AAM = True
55
+
56
+ # loss
57
+ self.loss_func = "STARLoss_v2"
58
+
59
+ # STAR Loss paras
60
+ self.star_w = 1
61
+ self.star_dist = 'smoothl1'
62
+
63
+ self.init_from_args(args)
64
+
65
+ # COFW
66
+ if self.data_definition == "COFW":
67
+ self.edge_info = (
68
+ (True, (0, 4, 2, 5)), # RightEyebrow
69
+ (True, (1, 6, 3, 7)), # LeftEyebrow
70
+ (True, (8, 12, 10, 13)), # RightEye
71
+ (False, (9, 14, 11, 15)), # LeftEye
72
+ (True, (18, 20, 19, 21)), # Nose
73
+ (True, (22, 26, 23, 27)), # LowerLip
74
+ (True, (22, 24, 23, 25)), # UpperLip
75
+ )
76
+ if self.norm_type == 'ocular':
77
+ self.nme_left_index = 8 # ocular
78
+ self.nme_right_index = 9 # ocular
79
+ elif self.norm_type in ['pupil', 'default']:
80
+ self.nme_left_index = 16 # pupil
81
+ self.nme_right_index = 17 # pupil
82
+ else:
83
+ raise NotImplementedError
84
+ self.classes_num = [29, 7, 29]
85
+ self.crop_op = True
86
+ self.flip_mapping = (
87
+ [0, 1], [4, 6], [2, 3], [5, 7], [8, 9], [10, 11], [12, 14], [16, 17], [13, 15], [18, 19], [22, 23],
88
+ )
89
+ self.image_dir = osp.join(self.image_dir, 'COFW')
90
+ # 300W
91
+ elif self.data_definition == "300W":
92
+ self.edge_info = (
93
+ (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), # FaceContour
94
+ (False, (17, 18, 19, 20, 21)), # RightEyebrow
95
+ (False, (22, 23, 24, 25, 26)), # LeftEyebrow
96
+ (False, (27, 28, 29, 30)), # NoseLine
97
+ (False, (31, 32, 33, 34, 35)), # Nose
98
+ (True, (36, 37, 38, 39, 40, 41)), # RightEye
99
+ (True, (42, 43, 44, 45, 46, 47)), # LeftEye
100
+ (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)), # OuterLip
101
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # InnerLip
102
+ )
103
+ if self.norm_type in ['ocular', 'default']:
104
+ self.nme_left_index = 36 # ocular
105
+ self.nme_right_index = 45 # ocular
106
+ elif self.norm_type == 'pupil':
107
+ self.nme_left_index = [36, 37, 38, 39, 40, 41] # pupil
108
+ self.nme_right_index = [42, 43, 44, 45, 46, 47] # pupil
109
+ else:
110
+ raise NotImplementedError
111
+ self.classes_num = [68, 9, 68]
112
+ self.crop_op = True
113
+ self.flip_mapping = (
114
+ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
115
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
116
+ [31, 35], [32, 34],
117
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
118
+ [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
119
+ )
120
+ self.image_dir = osp.join(self.image_dir, '300W')
121
+ # self.image_dir = osp.join(self.image_dir, '300VW_images')
122
+ # 300VW
123
+ elif self.data_definition == "300VW":
124
+ self.edge_info = (
125
+ (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), # FaceContour
126
+ (False, (17, 18, 19, 20, 21)), # RightEyebrow
127
+ (False, (22, 23, 24, 25, 26)), # LeftEyebrow
128
+ (False, (27, 28, 29, 30)), # NoseLine
129
+ (False, (31, 32, 33, 34, 35)), # Nose
130
+ (True, (36, 37, 38, 39, 40, 41)), # RightEye
131
+ (True, (42, 43, 44, 45, 46, 47)), # LeftEye
132
+ (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)), # OuterLip
133
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # InnerLip
134
+ )
135
+ if self.norm_type in ['ocular', 'default']:
136
+ self.nme_left_index = 36 # ocular
137
+ self.nme_right_index = 45 # ocular
138
+ elif self.norm_type == 'pupil':
139
+ self.nme_left_index = [36, 37, 38, 39, 40, 41] # pupil
140
+ self.nme_right_index = [42, 43, 44, 45, 46, 47] # pupil
141
+ else:
142
+ raise NotImplementedError
143
+ self.classes_num = [68, 9, 68]
144
+ self.crop_op = True
145
+ self.flip_mapping = (
146
+ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
147
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
148
+ [31, 35], [32, 34],
149
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
150
+ [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
151
+ )
152
+ self.image_dir = osp.join(self.image_dir, '300VW_Dataset_2015_12_14')
153
+ # WFLW
154
+ elif self.data_definition == "WFLW":
155
+ self.edge_info = (
156
+ (False, (
157
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
158
+ 27,
159
+ 28, 29, 30, 31, 32)), # FaceContour
160
+ (True, (33, 34, 35, 36, 37, 38, 39, 40, 41)), # RightEyebrow
161
+ (True, (42, 43, 44, 45, 46, 47, 48, 49, 50)), # LeftEyebrow
162
+ (False, (51, 52, 53, 54)), # NoseLine
163
+ (False, (55, 56, 57, 58, 59)), # Nose
164
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # RightEye
165
+ (True, (68, 69, 70, 71, 72, 73, 74, 75)), # LeftEye
166
+ (True, (76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87)), # OuterLip
167
+ (True, (88, 89, 90, 91, 92, 93, 94, 95)), # InnerLip
168
+ )
169
+ if self.norm_type in ['ocular', 'default']:
170
+ self.nme_left_index = 60 # ocular
171
+ self.nme_right_index = 72 # ocular
172
+ elif self.norm_type == 'pupil':
173
+ self.nme_left_index = 96 # pupils
174
+ self.nme_right_index = 97 # pupils
175
+ else:
176
+ raise NotImplementedError
177
+ self.classes_num = [98, 9, 98]
178
+ self.crop_op = True
179
+ self.flip_mapping = (
180
+ [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22],
181
+ [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], # cheek
182
+ [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], # eyebrow
183
+ [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73],
184
+ [55, 59], [56, 58],
185
+ [76, 82], [77, 81], [78, 80], [87, 83], [86, 84],
186
+ [88, 92], [89, 91], [95, 93], [96, 97]
187
+ )
188
+ self.image_dir = osp.join(self.image_dir, 'WFLW', 'WFLW_images')
189
+
190
+ self.label_num = self.nstack * 3 if self.use_AAM else self.nstack
191
+ self.loss_weights, self.criterions, self.metrics = [], [], []
192
+ for i in range(self.nstack):
193
+ factor = (2 ** i) / (2 ** (self.nstack - 1))
194
+ if self.use_AAM:
195
+ self.loss_weights += [factor * weight for weight in [1.0, 10.0, 10.0]]
196
+ self.criterions += [self.loss_func, "AWingLoss", "AWingLoss"]
197
+ self.metrics += ["NME", None, None]
198
+ else:
199
+ self.loss_weights += [factor * weight for weight in [1.0]]
200
+ self.criterions += [self.loss_func, ]
201
+ self.metrics += ["NME", ]
202
+
203
+ self.key_metric_index = (self.nstack - 1) * 3 if self.use_AAM else (self.nstack - 1)
204
+
205
+ # data
206
+ self.folder = self.get_foldername()
207
+ self.work_dir = osp.join(self.ckpt_dir, self.data_definition, self.folder)
208
+ self.model_dir = osp.join(self.work_dir, 'model')
209
+ self.log_dir = osp.join(self.work_dir, 'log')
210
+
211
+ self.train_tsv_file = osp.join(self.annot_dir, self.data_definition, "train.tsv")
212
+ self.train_pic_dir = self.image_dir
213
+
214
+ self.val_tsv_file = osp.join(self.annot_dir, self.data_definition, self.valset)
215
+ self.val_pic_dir = self.image_dir
216
+
217
+ self.test_tsv_file = osp.join(self.annot_dir, self.data_definition, self.test_file)
218
+ self.test_pic_dir = self.image_dir
219
+
220
+ # self.train_tsv_file = osp.join(self.annot_dir, '300VW', "train.tsv")
221
+ # self.train_pic_dir = self.image_dir
222
+
223
+ # self.val_tsv_file = osp.join(self.annot_dir, '300VW', self.valset)
224
+ # self.val_pic_dir = self.image_dir
225
+
226
+ # self.test_tsv_file = osp.join(self.annot_dir, '300VW', self.test_file)
227
+ # self.test_pic_dir = self.image_dir
228
+
229
+
230
+ def get_foldername(self):
231
+ name = ''
232
+ name += '{}_{}x{}_{}_ep{}_lr{}_bs{}'.format(self.data_definition, self.height, self.width,
233
+ self.optimizer, self.max_epoch, self.learn_rate, self.batch_size)
234
+ name += '_{}'.format(self.loss_func)
235
+ name += '_{}_{}'.format(self.star_dist, self.star_w) if self.loss_func == 'STARLoss' else ''
236
+ name += '_AAM' if self.use_AAM else ''
237
+ name += '_{}'.format(self.valset[:-4]) if self.valset != 'test.tsv' else ''
238
+ name += '_{}'.format(self.id)
239
+ return name
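For reference, the per-stack loss weighting built in `Alignment.__init__` scales each hourglass stack by `(2 ** i) / (2 ** (nstack - 1))`. A minimal standalone sketch (assuming the defaults `nstack = 4`, `use_AAM = True`, `loss_func = "STARLoss_v2"`) reproduces the resulting weight, criterion, and metric lists:

```python
# Sketch of the per-stack loss weighting from Alignment.__init__ (assumed default values).
nstack, use_AAM, loss_func = 4, True, "STARLoss_v2"

loss_weights, criterions, metrics = [], [], []
for i in range(nstack):
    factor = (2 ** i) / (2 ** (nstack - 1))  # 0.125, 0.25, 0.5, 1.0
    if use_AAM:
        loss_weights += [factor * w for w in [1.0, 10.0, 10.0]]
        criterions += [loss_func, "AWingLoss", "AWingLoss"]
        metrics += ["NME", None, None]
    else:
        loss_weights += [factor]
        criterions += [loss_func]
        metrics += ["NME"]

print(loss_weights)
# [0.125, 1.25, 1.25, 0.25, 2.5, 2.5, 0.5, 5.0, 5.0, 1.0, 10.0, 10.0]
```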
external/landmark_detection/conf/base.py ADDED
@@ -0,0 +1,94 @@
1
+ import uuid
2
+ import logging
3
+ import os.path as osp
4
+ from argparse import Namespace
5
+ # from tensorboardX import SummaryWriter
6
+
7
+ class Base:
8
+ """
9
+ Base configuration, which contains the basic training parameters and should be inherited by task-specific configuration classes.
10
+ """
11
+
12
+ def __init__(self, config_name, ckpt_dir='./', image_dir='./', annot_dir='./'):
13
+ self.type = config_name
14
+ self.id = str(uuid.uuid4())
15
+ self.note = ""
16
+
17
+ self.ckpt_dir = ckpt_dir
18
+ self.image_dir = image_dir
19
+ self.annot_dir = annot_dir
20
+
21
+ self.loader_type = "alignment"
22
+ self.loss_func = "STARLoss"
23
+
24
+ # train
25
+ self.batch_size = 128
26
+ self.val_batch_size = 1
27
+ self.test_batch_size = 32
28
+ self.channels = 3
29
+ self.width = 256
30
+ self.height = 256
31
+
32
+ # mean values in r, g, b channel.
33
+ self.means = (127, 127, 127)
34
+ self.scale = 0.0078125
35
+
36
+ self.display_iteration = 100
37
+ self.milestones = [50, 80]
38
+ self.max_epoch = 100
39
+
40
+ self.net = "stackedHGnet_v1"
41
+ self.nstack = 4
42
+
43
+ # ["adam", "sgd"]
44
+ self.optimizer = "adam"
45
+ self.learn_rate = 0.1
46
+ self.momentum = 0.01 # caffe: 0.99
47
+ self.weight_decay = 0.0
48
+ self.nesterov = False
49
+ self.scheduler = "MultiStepLR"
50
+ self.gamma = 0.1
51
+
52
+ self.loss_weights = [1.0]
53
+ self.criterions = ["SoftmaxWithLoss"]
54
+ self.metrics = ["Accuracy"]
55
+ self.key_metric_index = 0
56
+ self.classes_num = [1000]
57
+ self.label_num = len(self.classes_num)
58
+
59
+ # model
60
+ self.ema = False
61
+ self.use_AAM = True
62
+
63
+ # visualization
64
+ self.writer = None
65
+
66
+ # log file
67
+ self.logger = None
68
+
69
+ def init_instance(self):
70
+ # self.writer = SummaryWriter(logdir=self.log_dir, comment=self.type)
71
+ log_formatter = logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s")
72
+ root_logger = logging.getLogger()
73
+ file_handler = logging.FileHandler(osp.join(self.log_dir, "log.txt"))
74
+ file_handler.setFormatter(log_formatter)
75
+ file_handler.setLevel(logging.NOTSET)
76
+ root_logger.addHandler(file_handler)
77
+ console_handler = logging.StreamHandler()
78
+ console_handler.setFormatter(log_formatter)
79
+ console_handler.setLevel(logging.NOTSET)
80
+ root_logger.addHandler(console_handler)
81
+ root_logger.setLevel(logging.NOTSET)
82
+ self.logger = root_logger
83
+
84
+ def __del__(self):
85
+ # tensorboard --logdir self.log_dir
86
+ if self.writer is not None:
87
+ # self.writer.export_scalars_to_json(self.log_dir + "visual.json")
88
+ self.writer.close()
89
+
90
+ def init_from_args(self, args: Namespace):
91
+ args_vars = vars(args)
92
+ for key, value in args_vars.items():
93
+ if hasattr(self, key) and value is not None:
94
+ setattr(self, key, value)
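`Base.init_from_args` only overrides attributes that already exist on the config and whose command-line value is not `None`. A small self-contained sketch (using a reduced stand-in class, not the real `Base`) illustrates that rule:

```python
from argparse import Namespace

class TinyConfig:
    """Reduced stand-in for Base, only to illustrate init_from_args."""
    def __init__(self):
        self.batch_size = 128
        self.learn_rate = 0.1

    def init_from_args(self, args: Namespace):
        # Same override rule as Base.init_from_args above.
        for key, value in vars(args).items():
            if hasattr(self, key) and value is not None:
                setattr(self, key, value)

cfg = TinyConfig()
cfg.init_from_args(Namespace(batch_size=32, learn_rate=None, unknown_flag=7))
print(cfg.batch_size, cfg.learn_rate)  # 32 0.1 -- None values and unknown keys are ignored
```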
external/landmark_detection/config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "Token":"bpt4JPotFA6bpdknR9ZDCw",
3
+ "business_flag": "shadow_cv_face",
4
+ "model_local_file_path": "/apdcephfs_cq3/share_1134483/charlinzhou/Documents/awesome-tools/jizhi/",
5
+ "host_num": 1,
6
+ "host_gpu_num": 1,
7
+ "GPUName": "V100",
8
+ "is_elasticity": true,
9
+ "enable_evicted_pulled_up": true,
10
+ "task_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
11
+ "task_flag": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
12
+ "model_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
13
+ "image_full_name": "mirrors.tencent.com/haroldzcli/py36-pytorch1.7.1-torchvision0.8.2-cuda10.1-cudnn7.6",
14
+ "start_cmd": "./start_slpt.sh /apdcephfs_cq3/share_1134483/charlinzhou/Documents/SLPT_Training train.py --loss_func=star --bb_init --eigen_box --dist_func=align_smoothl1"
15
+ }
external/landmark_detection/data_processor/CheckFaceKeyPoint.py ADDED
@@ -0,0 +1,147 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ selected_indices_old = [
8
+ 2311,
9
+ 2416,
10
+ 2437,
11
+ 2460,
12
+ 2495,
13
+ 2518,
14
+ 2520,
15
+ 2627,
16
+ 4285,
17
+ 4315,
18
+ 6223,
19
+ 6457,
20
+ 6597,
21
+ 6642,
22
+ 6974,
23
+ 7054,
24
+ 7064,
25
+ 7182,
26
+ 7303,
27
+ 7334,
28
+ 7351,
29
+ 7368,
30
+ 7374,
31
+ 7493,
32
+ 7503,
33
+ 7626,
34
+ 8443,
35
+ 8562,
36
+ 8597,
37
+ 8701,
38
+ 8817,
39
+ 8953,
40
+ 11213,
41
+ 11261,
42
+ 11317,
43
+ 11384,
44
+ 11600,
45
+ 11755,
46
+ 11852,
47
+ 11891,
48
+ 11945,
49
+ 12010,
50
+ 12354,
51
+ 12534,
52
+ 12736,
53
+ 12880,
54
+ 12892,
55
+ 13004,
56
+ 13323,
57
+ 13371,
58
+ 13534,
59
+ 13575,
60
+ 14874,
61
+ 14949,
62
+ 14977,
63
+ 15052,
64
+ 15076,
65
+ 15291,
66
+ 15620,
67
+ 15758,
68
+ 16309,
69
+ 16325,
70
+ 16348,
71
+ 16390,
72
+ 16489,
73
+ 16665,
74
+ 16891,
75
+ 17147,
76
+ 17183,
77
+ 17488,
78
+ 17549,
79
+ 17657,
80
+ 17932,
81
+ 19661,
82
+ 20162,
83
+ 20200,
84
+ 20238,
85
+ 20286,
86
+ 20432,
87
+ 20834,
88
+ 20954,
89
+ 21015,
90
+ 21036,
91
+ 21117,
92
+ 21299,
93
+ 21611,
94
+ 21632,
95
+ 21649,
96
+ 22722,
97
+ 22759,
98
+ 22873,
99
+ 23028,
100
+ 23033,
101
+ 23082,
102
+ 23187,
103
+ 23232,
104
+ 23302,
105
+ 23413,
106
+ 23430,
107
+ 23446,
108
+ 23457,
109
+ 23548,
110
+ 23636,
111
+ 32060,
112
+ 32245,
113
+ ]
114
+
115
+ selected_indices = list()
116
+ with open('/home/gyalex/Desktop/face_anno.txt', 'r') as f:
117
+ lines = f.readlines()
118
+ for line in lines:
119
+ hh = line.strip().split()
120
+ if len(hh) > 0:
121
+ pid = hh[0].find('.')
122
+ if pid != -1:
123
+ s = hh[0][pid+1:len(hh[0])]
124
+ print(s)
125
+ selected_indices.append(int(s))
126
+
127
+ f.close()
128
+
129
+ dir = '/media/gyalex/Data/face_ldk_dataset/MHC_LightingPreset_Portrait_RT_0_19/MHC_LightingPreset_Portrait_RT_seq_000015'
130
+
131
+ for idx in range(500):
132
+ img = os.path.join(dir, "view_1/MHC_LightingPreset_Portrait_RT_seq_000015_FinalImage_" + str(idx).zfill(4) + ".jpeg")
133
+ lmd = os.path.join(dir, "mesh/mesh_screen" + str(idx+5).zfill(7) + ".npy")
134
+
135
+ img = cv2.imread(img)
136
+ # c = 511 / 2
137
+ # lmd = np.load(lmd) * c + c
138
+ # lmd[:, 1] = 511 - lmd[:, 1]
139
+ lmd = np.load(lmd)[selected_indices]
140
+ for i in range(lmd.shape[0]):
141
+ p = lmd[i]
142
+ x, y = round(float(p[0])), round(float(p[1]))
143
+ print(p)
144
+ cv2.circle(img, (x, y), 2, (0, 0, 255), -1)
145
+
146
+ cv2.imshow('win', img)
147
+ cv2.waitKey(0)
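The script above depends on hard-coded local paths. As a path-agnostic sketch of the same visualization step (hypothetical `image.jpeg` / `mesh_screen.npy` inputs), the overlay logic reduces to:

```python
import cv2
import numpy as np

def draw_landmarks(image_path, landmarks_path, indices=None, radius=2, color=(0, 0, 255)):
    """Overlay pixel-space landmarks stored in a .npy file onto an image."""
    img = cv2.imread(image_path)
    pts = np.load(landmarks_path)          # expected shape (N, >=2), pixel coordinates
    if indices is not None:
        pts = pts[indices]                 # e.g. the selected_indices list above
    for p in pts:
        x, y = int(round(float(p[0]))), int(round(float(p[1])))
        cv2.circle(img, (x, y), radius, color, -1)
    return img

# Hypothetical inputs; replace with a real frame and landmark file.
# vis = draw_landmarks("image.jpeg", "mesh_screen.npy")
# cv2.imwrite("image_landmarks.jpeg", vis)
```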
external/landmark_detection/data_processor/align.py ADDED
@@ -0,0 +1,193 @@
1
+ import numpy as np
2
+ import open3d as o3d
3
+ from scipy.spatial.transform import Rotation
4
+ from scipy.linalg import orthogonal_procrustes
5
+
6
+ from open3d.pipelines.registration import registration_ransac_based_on_correspondence
7
+
8
+
9
+ def rigid_transform_3D(A, B):
10
+ assert A.shape == B.shape, "Input arrays must have the same shape"
11
+ assert A.shape[1] == 3, "Input arrays must be Nx3"
12
+
13
+ N = A.shape[0] # Number of points
14
+
15
+ # Compute centroids of A and B
16
+ centroid_A = np.mean(A, axis=0)
17
+ centroid_B = np.mean(B, axis=0)
18
+
19
+ # Center the points around the centroids
20
+ AA = A - centroid_A
21
+ BB = B - centroid_B
22
+
23
+ # H = AA^T * BB
24
+ H = np.dot(AA.T, BB)
25
+
26
+ # Singular Value Decomposition
27
+ U, S, Vt = np.linalg.svd(H)
28
+
29
+ # Compute rotation
30
+ R = np.dot(Vt.T, U.T)
31
+
32
+ # Ensure a proper rotation (det(R) should be +1)
33
+ if np.linalg.det(R) < 0:
34
+ Vt[2, :] *= -1
35
+ R = np.dot(Vt.T, U.T)
36
+
37
+ # Compute translation
38
+ t = centroid_B - np.dot(R, centroid_A)
39
+
40
+ # Construct the transform matrix (4x4)
41
+ transform_matrix = np.eye(4)
42
+ transform_matrix[:3, :3] = R
43
+ transform_matrix[:3, 3] = t
44
+
45
+ return transform_matrix
46
+
47
+
48
+ def compute_rigid_transform(points1, points2):
49
+ """
50
+ Compute the rigid transform (scale, rotation, and translation) from points1 to points2.
52
+
53
+ Args:
54
+ points1, points2: np.ndarray of shape (68, 3), the two sets of corresponding 3D points.
55
+
56
+ Returns:
57
+ scale: float, scale factor
58
+ R: np.ndarray, 3x3 rotation matrix
59
+ t: np.ndarray, translation vector of length 3
59
+ """
60
+ # Center the points
61
+ mean1 = np.mean(points1, axis=0)
62
+ centered_points1 = points1 - mean1
63
+ mean2 = np.mean(points2, axis=0)
64
+ centered_points2 = points2 - mean2
65
+
66
+ # Use orthogonal_procrustes to compute the rotation, then derive the translation
67
+ R, _ = orthogonal_procrustes(centered_points1, centered_points2)
68
+ t = mean2 - R @ mean1 # compute the translation vector
69
+
70
+ # Compute the scale factor
71
+ scale = np.mean(np.linalg.norm(centered_points2, axis=1) /
72
+ np.linalg.norm(centered_points1, axis=1))
73
+
74
+ return scale, R, t
75
+
76
+
77
+ def compute_rigid_transform_new(points_A, points_B):
78
+ # Center the points
79
+ center_A = np.mean(points_A, axis=0)
80
+ center_B = np.mean(points_B, axis=0)
81
+ points_A_centered = points_A - center_A
82
+ points_B_centered = points_B - center_B
83
+
84
+ # Compute the covariance matrix
85
+ cov_matrix = np.dot(points_A_centered.T, points_B_centered)
86
+
87
+ # Singular value decomposition
88
+ U, S, Vt = np.linalg.svd(cov_matrix)
89
+
90
+ # Ensure the rotation matrix is orthogonal and right-handed; here we take Vt.T @ U.T as the rotation
91
+ rotation_matrix = np.dot(Vt.T, U.T)
92
+
93
+ # If the determinant is -1 (a reflection, not a valid rotation), flip the sign of one row of Vt
94
+ if np.linalg.det(rotation_matrix) < 0:
95
+ Vt[2,:] *= -1
96
+ rotation_matrix = np.dot(Vt.T, U.T)
97
+
98
+ # Compute the scale factor
99
+ scale = np.trace(np.dot(points_A_centered.T, points_B_centered)) / np.trace(np.dot(points_A_centered.T, points_A_centered))
100
+
101
+ # Compute the translation vector
102
+ translation_vector = center_B - scale * np.dot(rotation_matrix, center_A)
103
+
104
+ return scale, rotation_matrix, translation_vector
105
+
106
+
107
+
108
+
109
+ # Example usage
110
+ obj_A = '/home/gyalex/Desktop/our_face.obj'
111
+ obj_B = '/home/gyalex/Desktop/Neutral.obj'
112
+
113
+ mesh_A = o3d.io.read_triangle_mesh(obj_A)
114
+ mesh_B = o3d.io.read_triangle_mesh(obj_B)
115
+
116
+ vertices_A = np.asarray(mesh_A.vertices)
117
+ vertices_B = np.asarray(mesh_B.vertices)
118
+
119
+ list_A = list()
120
+ list_B = list()
121
+ with open('/home/gyalex/Desktop/our_marker.txt', 'r') as f:
122
+ lines_A = f.readlines()
123
+ for line in lines_A:
124
+ hh = line.strip().split()
125
+ list_A.append(int(hh[0]))
126
+
127
+ with open('/home/gyalex/Desktop/ARKit_landmarks.txt', 'r') as f:
128
+ lines_B = f.readlines()
129
+ for line in lines_B:
130
+ hh = line.strip().split()
131
+ list_B.append(int(hh[0]))
132
+
133
+ A = vertices_A[list_A,:] # first set of 3D points
134
+ B = vertices_B[list_B,:] # second set of 3D points
135
+
136
+ # scale, R, t = compute_rigid_transform(A, B)
137
+
138
+ # # Define the scale transform matrix
139
+ # scale_matrix = np.eye(4)
140
+ # scale_matrix[0, 0] = scale # scale along the x-axis
141
+ # scale_matrix[1, 1] = scale # scale along the y-axis
142
+ # scale_matrix[2, 2] = scale # scale along the z-axis
143
+
144
+ # transform_matrix = np.eye(4)
145
+ # transform_matrix[:3, :3] = scale
146
+ # transform_matrix[:3, 3] = R*t
147
+
148
+ # mesh_A.transform(transform_matrix)
149
+ # # mesh_A.transform(scale_matrix)
150
+
151
+ # o3d.io.write_triangle_mesh('/home/gyalex/Desktop/our_face_new.obj', mesh_A)
152
+
153
+ pcd_source = o3d.utility.Vector3dVector(A) # source correspondence points
154
+ pcd_target = o3d.utility.Vector3dVector(B) # target correspondence points (example data)
155
+
156
+ corres_source = list()
157
+ for idx in range(68): corres_source.append(idx)
158
+ corres_target = list()
159
+ for idx in range(68): corres_target.append(idx)
160
+
161
+ # Gather the actual corresponding point coordinates from the index lists
162
+ corres_source_points = pcd_source
163
+ corres_target_points = pcd_target
164
+
165
+ corres = o3d.utility.Vector2iVector([[src, tgt] for src, tgt in zip(corres_source, corres_target)])
166
+
167
+ # Run RANSAC registration based on the given correspondences
168
+ reg_result = registration_ransac_based_on_correspondence(
169
+ pcd_source,
170
+ pcd_target,
171
+ corres,
172
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(),
173
+ ransac_n=3,
174
+ criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(max_iteration=100000, epsilon=1e-6)
175
+ )
176
+
177
+ # # Register with RANSAC
178
+ # convergence_criteria = o3d.pipelines.registration.RANSACConvergenceCriteria(max_iteration=50000, max_validation=500)
179
+ # ransac_result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
180
+ # pcd_source,
181
+ # pcd_target,
182
+ # corres,
183
+ # o3d.pipelines.registration.TransformationEstimationPointToPoint(),
184
+ # 3, # RANSAC threshold; adjust to the data
185
+ # convergence_criteria,
186
+ # [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
187
+ # o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(0.05)],
188
+ # o3d.pipelines.registration.RANSACLoss())
189
+
190
+ # Apply the transform to the source mesh
191
+ # mesh_source_aligned = mesh_source.transform(reg_result.transformation)
192
+
193
+ a = 0
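A quick synthetic self-check for `rigid_transform_3D` above (no Open3D required, assuming the function is in scope): build a known rotation and translation, apply them to a point set, and confirm the recovered 4x4 transform maps A onto B.

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Synthetic check: recover a known rotation/translation with rigid_transform_3D (defined above).
rng = np.random.default_rng(0)
A = rng.normal(size=(68, 3))
R_true = Rotation.from_euler("xyz", [10, -20, 30], degrees=True).as_matrix()
t_true = np.array([0.5, -1.0, 2.0])
B = A @ R_true.T + t_true

T = rigid_transform_3D(A, B)                        # 4x4 homogeneous transform
A_h = np.concatenate([A, np.ones((68, 1))], axis=1)
B_rec = (A_h @ T.T)[:, :3]

print(np.allclose(B_rec, B, atol=1e-6))             # True: B is recovered
print(np.allclose(T[:3, :3], R_true, atol=1e-6))    # True: rotation is recovered
```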
external/landmark_detection/data_processor/process_pcd.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import open3d as o3d
5
+ # import pyrender
6
+ # from pyrender import mesh, DirectionalLight, Material, PerspectiveCamera
7
+
8
+ os.environ['__GL_THREADED_OPTIMIZATIONS'] = '1'
9
+
10
+ cord_list = []
11
+ with open('./cord.txt', 'r') as f:
12
+ lines = f.readlines()
13
+ for line in lines:
14
+ m = line.split()
15
+ x = int(m[0])
16
+ y = int(m[1])
17
+
18
+ x = 1000 - x
19
+ y = 1000 - y
20
+
21
+ cord_list.append([x, y])
22
+
23
+
24
+ # Assumed path of the input TXT files
25
+ output_folder = '/media/gyalex/Data/face_det_dataset/rgbd_data/rgbd'
26
+ if not os.path.exists(output_folder):
27
+ os.mkdir(output_folder)
28
+
29
+ for idx in range(32, 33):
30
+ txt_file_path = '/media/gyalex/Data/face_det_dataset/rgbd_data/PointImage'+ str(idx) + '.txt'
31
+ _, name = os.path.split(txt_file_path)
32
+ print(txt_file_path)
33
+
34
+ with open(txt_file_path, 'r') as file:
35
+ points = []
36
+ rgb_list = []
37
+ ori_rgb_list = []
38
+ normal_list = []
39
+
40
+ # Read the data line by line
41
+ for line in file:
42
+ # Strip the trailing newline and split the string
43
+ x, y, z, r, g, b, nx, ny, nz, w = line.split()
44
+ # Convert the strings to floats
45
+ x = float(x)
46
+ y = float(y)
47
+ z = float(z)
48
+ r = float(r)
49
+ g = float(g)
50
+ b = float(b)
51
+ nx = float(nx)
52
+ ny = float(ny)
53
+ nz = float(nz)
54
+ # Append the point to the lists
55
+ points.append((x, y, z))
56
+ rgb_list.append((r/255.0, g/255.0 , b/255.0))
57
+ normal_list.append((nx, ny, nz))
58
+
59
+ ori_r = int(r)
60
+ ori_g = int(g)
61
+ ori_b = int(b)
62
+ ori_rgb_list.append((ori_r, ori_g , ori_b))
63
+
64
+ np_points = np.asarray(points)
65
+
66
+ np_points_a = np_points
67
+
68
+ np_colors = np.asarray(rgb_list)
69
+ np_normals = np.asarray(normal_list)
70
+
71
+ np_colors_ori = np.asarray(ori_rgb_list)
72
+
73
+ pcd = o3d.geometry.PointCloud()
74
+ pcd.points = o3d.utility.Vector3dVector(np_points)
75
+ pcd.colors = o3d.utility.Vector3dVector(np_colors)
76
+ pcd.normals = o3d.utility.Vector3dVector(np_normals)
77
+
78
+ map_dict = {}
79
+
80
+ image = np.ones((1000, 1000, 3),dtype=np.uint8)*255
81
+ for i in range(np.array(pcd.points).shape[0]):
82
+ x = np.array(pcd.points)[i,0]+400
83
+ y = np.array(pcd.points)[i,1]+400
84
+
85
+ image[int(x),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
86
+ image[int(x+1),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
87
+ image[int(x),int(y+1),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
88
+ image[int(x-1),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
89
+ image[int(x),int(y-1),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
90
+
91
+ map_dict[str(int(x)) + '_' + str(int(y))] = i
92
+ map_dict[str(int(x+1)) + '_' + str(int(y))] = i
93
+ map_dict[str(int(x)) + '_' + str(int(y+1))] = i
94
+ map_dict[str(int(x-1)) + '_' + str(int(y))] = i
95
+ map_dict[str(int(x)) + '_' + str(int(y-1))] = i
96
+
97
+ # if [int(y), int(x)] in cord_list:
98
+ # image[int(x),int(y),:] = np.array([0, 255, 0])
99
+
100
+ # if [int(y), int(x+1)] in cord_list:
101
+ # image[int(x+1),int(y),:] = np.array([0, 255, 0])
102
+
103
+ # if [int(y+1), int(x)] in cord_list:
104
+ # image[int(x),int(y+1),:] = np.array([0, 255, 0])
105
+
106
+ # if [int(y), int(x-1)] in cord_list:
107
+ # image[int(x-1),int(y),:] = np.array([0, 255, 0])
108
+
109
+ # if [int(y-1), int(x)] in cord_list:
110
+ # image[int(x),int(y-1),:] = np.array([0, 255, 0])
111
+
112
+ # if [int(y-1), int(x-1)] in cord_list:
113
+ # image[int(x-1),int(y-1),:] = np.array([0, 255, 0])
114
+
115
+ # if [int(y+1), int(x+1)] in cord_list:
116
+ # image[int(x+1),int(y+1),:] = np.array([0, 255, 0])
117
+
118
+ h_list = []
119
+ for m in cord_list:
120
+ a, b = m[0], m[1]
121
+ c = image[int(b),int(a),:][0]
122
+
123
+ flag = False
124
+
125
+ if image[int(b),int(a),:][1] != 255:
126
+ h_list.append(str(int(b))+'_'+str(int(a)))
127
+ flag = True
128
+ else:
129
+ if image[int(b)-2,int(a)-2,:][1] != 255:
130
+ h_list.append(str(int(b)-2)+'_'+str(int(a)-2))
131
+ flag = True
132
+ elif image[int(b)+2,int(a)+2,:][1] != 255:
133
+ h_list.append(str(int(b)+2)+'_'+str(int(a)+2))
134
+ flag = True
135
+ elif image[int(b),int(a)-3,:][1] != 255:
136
+ h_list.append(str(int(b))+'_'+str(int(a)-3))
137
+ flag = True
138
+
139
+ # if flag == False:
140
+ # cc = image[int(b),int(a),:][1]
141
+
142
+ # cv2.circle(image, (465,505), 2, (0, 255, 0), -1)
143
+
144
+ # cv2.imshow('win', image)
145
+ # cv2.waitKey(0)
146
+
147
+ with open('pid.txt', 'w') as f:
148
+ for h in h_list:
149
+ pid = map_dict[h]
150
+ s = str(pid) + '\n'
151
+ f.write(s)
152
+
153
+ np_colors[pid,:] = np.array([0, 255, 0])
154
+
155
+ f.close()
156
+
157
+ pcd0 = o3d.geometry.PointCloud()
158
+ pcd0.points = o3d.utility.Vector3dVector(np_points)
159
+ pcd0.colors = o3d.utility.Vector3dVector(np_colors)
160
+ pcd0.normals = o3d.utility.Vector3dVector(np_normals)
161
+
162
+ o3d.io.write_point_cloud('aa.ply', pcd0)
163
+
164
+
165
+ mm = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
166
+ image3 = cv2.flip(mm, -1)
167
+
168
+ # cv2.imwrite('./rgb.png', image3)
169
+
170
+ with open('./cord.txt', 'r') as f:
171
+ lines = f.readlines()
172
+ for line in lines:
173
+ m = line.split()
174
+ x = int(m[0])
175
+ y = int(m[1])
176
+
177
+ x = 1000 - x
178
+ y = 1000 - y
179
+
180
+ cv2.circle(image, (x,y), 2, (0, 255, 0), -1)
181
+
182
+ idx = map_dict[str(x)+'_'+str(y)]
183
+
184
+ a = 0
185
+
186
+ # cv2.imshow("win", image)
187
+ # cv2.waitKey(0)
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+ # import matplotlib.pyplot as plt
203
+ # plt.imshow(image)
204
+ # plt.show()
205
+
206
+ # save_pcd_path = os.path.join(output_folder, name[:-3]+'ply')
207
+ # # o3d.io.write_point_cloud(save_pcd_path, pcd)
208
+
209
+ # # render
210
+ # import trimesh
211
+ # # fuze_trimesh = trimesh.load('/home/gyalex/Desktop/PointImage32.obj')
212
+ # # mesh = pyrender.Mesh.from_trimesh(fuze_trimesh)
213
+ # mesh = pyrender.Mesh.from_points(np_points, np_colors_ori, np_normals)
214
+
215
+ # import math
216
+ # camera = PerspectiveCamera(yfov=math.pi / 3, aspectRatio=1.0)
217
+ # camera_pose = np.array([[-1.0, 0.0, 0.0, 0], \
218
+ # [0.0, 1.0, 0.0, 0], \
219
+ # [0.0, 0.0, -1.0, 0], \
220
+ # [0.0, 0.0, 0.0, 1.0]])
221
+
222
+ # # Create the scene
223
+ # scene = pyrender.Scene()
224
+ # scene.add(mesh)
225
+ # scene.add(camera, pose=camera_pose)
226
+
227
+ # # light = pyrender.SpotLight(color=np.ones(3), intensity=3.0, innerConeAngle=np.pi/16.0, outerConeAngle=np.pi/6.0)
228
+ # # scene.add(light, pose=camera_pose)
229
+
230
+ # # Render the scene
231
+ # renderer = pyrender.OffscreenRenderer(viewport_width=1280, viewport_height=1024)
232
+ # color, depth = renderer.render(scene)
233
+
234
+ # # # Set up the scene and light source
235
+ # # scene = pyrender.Scene()
236
+ # # scene.add(point_cloud_mesh, 'point_cloud')
237
+ # # camera = PerspectiveCamera(yfov=45.0, aspectRatio=1.0)
238
+ # # scene.add(camera)
239
+
240
+ # # # Render the scene
241
+ # # renderer = pyrender.OffscreenRenderer(viewport_width=1280, viewport_height=1024)
242
+ # # color, depth = renderer.render(scene)
243
+
244
+ # # Save the rendered result as an image
245
+ # import cv2
246
+ # cv2.imshow('win', color)
247
+
248
+ # rgb_img = cv2.imread('/media/gyalex/Data/face_det_dataset/rgbd_data/color_32.bmp')
249
+ # cv2.imshow('win0', rgb_img)
250
+ # cv2.waitKey(0)
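The projection above is a simple orthographic splat: drop z, offset x/y by a fixed amount, and write each point's color into a fixed-size canvas while remembering which pixel came from which point. A compact sketch of that step (assumed 1000x1000 canvas and +400 offset, as in the script):

```python
import numpy as np

def splat_points(points, colors, size=1000, offset=400):
    """Orthographic splat of colored 3D points into a 2D canvas (as done in the script above)."""
    canvas = np.full((size, size, 3), 255, dtype=np.uint8)
    index_map = {}  # "row_col" -> point index, mirroring map_dict above
    for i, (p, c) in enumerate(zip(points, colors)):
        row, col = int(p[0] + offset), int(p[1] + offset)
        if 0 <= row < size and 0 <= col < size:
            canvas[row, col] = (np.asarray(c) * 255).astype(np.uint8)
            index_map[f"{row}_{col}"] = i
    return canvas, index_map
```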
external/landmark_detection/evaluate.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+
10
+ # private package
11
+ from lib import utility
12
+
13
+
14
+
15
+ class GetCropMatrix():
16
+ """
17
+ from_shape -> transform_matrix
18
+ """
19
+
20
+ def __init__(self, image_size, target_face_scale, align_corners=False):
21
+ self.image_size = image_size
22
+ self.target_face_scale = target_face_scale
23
+ self.align_corners = align_corners
24
+
25
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
26
+ cosv = math.cos(angle)
27
+ sinv = math.sin(angle)
28
+
29
+ fx, fy = from_center
30
+ tx, ty = to_center
31
+
32
+ acos = scale * cosv
33
+ asin = scale * sinv
34
+
35
+ a0 = acos
36
+ a1 = -asin
37
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
38
+
39
+ b0 = asin
40
+ b1 = acos
41
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
42
+
43
+ rot_scale_m = np.array([
44
+ [a0, a1, a2],
45
+ [b0, b1, b2],
46
+ [0.0, 0.0, 1.0]
47
+ ], np.float32)
48
+ return rot_scale_m
49
+
50
+ def process(self, scale, center_w, center_h):
51
+ if self.align_corners:
52
+ to_w, to_h = self.image_size - 1, self.image_size - 1
53
+ else:
54
+ to_w, to_h = self.image_size, self.image_size
55
+
56
+ rot_mu = 0
57
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
58
+ shift_xy_mu = (0, 0)
59
+ matrix = self._compose_rotate_and_scale(
60
+ rot_mu, scale_mu, shift_xy_mu,
61
+ from_center=[center_w, center_h],
62
+ to_center=[to_w / 2.0, to_h / 2.0])
63
+ return matrix
64
+
65
+
66
+ class TransformPerspective():
67
+ """
68
+ image, matrix3x3 -> transformed_image
69
+ """
70
+
71
+ def __init__(self, image_size):
72
+ self.image_size = image_size
73
+
74
+ def process(self, image, matrix):
75
+ return cv2.warpPerspective(
76
+ image, matrix, dsize=(self.image_size, self.image_size),
77
+ flags=cv2.INTER_LINEAR, borderValue=0)
78
+
79
+
80
+ class TransformPoints2D():
81
+ """
82
+ points (nx2), matrix (3x3) -> points (nx2)
83
+ """
84
+
85
+ def process(self, srcPoints, matrix):
86
+ # nx3
87
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
88
+ desPoints = desPoints @ np.transpose(matrix) # nx3
89
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
90
+ return desPoints.astype(srcPoints.dtype)
91
+
92
+
93
+ class Alignment:
94
+ def __init__(self, args, model_path, dl_framework, device_ids):
95
+ self.input_size = 256
96
+ self.target_face_scale = 1.0
97
+ self.dl_framework = dl_framework
98
+
99
+ # model
100
+ if self.dl_framework == "pytorch":
101
+ # conf
102
+ self.config = utility.get_config(args)
103
+ self.config.device_id = device_ids[0]
104
+ # set environment
105
+ utility.set_environment(self.config)
106
+ self.config.init_instance()
107
+ if self.config.logger is not None:
108
+ self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
109
+ self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
110
+
111
+ net = utility.get_net(self.config)
112
+ if device_ids == [-1]:
113
+ checkpoint = torch.load(model_path, map_location="cpu")
114
+ else:
115
+ checkpoint = torch.load(model_path)
116
+ net.load_state_dict(checkpoint["net"])
117
+ net = net.to(self.config.device_id)
118
+ net.eval()
119
+ self.alignment = net
120
+ else:
121
+ assert False
122
+
123
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
124
+ align_corners=True)
125
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
126
+ self.transformPoints2D = TransformPoints2D()
127
+
128
+ def norm_points(self, points, align_corners=False):
129
+ if align_corners:
130
+ # [0, SIZE-1] -> [-1, +1]
131
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
132
+ else:
133
+ # [-0.5, SIZE-0.5] -> [-1, +1]
134
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
135
+
136
+ def denorm_points(self, points, align_corners=False):
137
+ if align_corners:
138
+ # [-1, +1] -> [0, SIZE-1]
139
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
140
+ else:
141
+ # [-1, +1] -> [-0.5, SIZE-0.5]
142
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
143
+
144
+ def preprocess(self, image, scale, center_w, center_h):
145
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
146
+ input_tensor = self.transformPerspective.process(image, matrix)
147
+ input_tensor = input_tensor[np.newaxis, :]
148
+
149
+ input_tensor = torch.from_numpy(input_tensor)
150
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
151
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
152
+ input_tensor = input_tensor.to(self.config.device_id)
153
+ return input_tensor, matrix
154
+
155
+ def postprocess(self, srcPoints, coeff):
156
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
157
+ # matrix^(-1) * src = dst
158
+ # src = matrix * dst
159
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
160
+ for i in range(srcPoints.shape[0]):
161
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
162
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
163
+ return dstPoints
164
+
165
+ def analyze(self, image, scale, center_w, center_h):
166
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
167
+
168
+ if self.dl_framework == "pytorch":
169
+ with torch.no_grad():
170
+ output = self.alignment(input_tensor)
171
+ landmarks = output[-1][0]
172
+ else:
173
+ assert False
174
+
175
+ landmarks = self.denorm_points(landmarks)
176
+ landmarks = landmarks.data.cpu().numpy()[0]
177
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
178
+
179
+ return landmarks
180
+
181
+
182
+ def L2(p1, p2):
183
+ return np.linalg.norm(p1 - p2)
184
+
185
+
186
+ def NME(landmarks_gt, landmarks_pv):
187
+ pts_num = landmarks_gt.shape[0]
188
+ if pts_num == 29:
189
+ left_index = 16
190
+ right_index = 17
191
+ elif pts_num == 68:
192
+ left_index = 36
193
+ right_index = 45
194
+ elif pts_num == 98:
195
+ left_index = 60
196
+ right_index = 72
197
+
198
+ nme = 0
199
+ eye_span = L2(landmarks_gt[left_index], landmarks_gt[right_index])
200
+ for i in range(pts_num):
201
+ error = L2(landmarks_pv[i], landmarks_gt[i])
202
+ nme += error / eye_span
203
+ nme /= pts_num
204
+ return nme
205
+
206
+
207
+ def evaluate(args, model_path, metadata_path, device_ids, mode):
208
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
209
+ config = alignment.config
210
+ nme_sum = 0
211
+ with open(metadata_path, 'r') as f:
212
+ lines = f.readlines()
213
+ for k, line in enumerate(tqdm(lines)):
214
+ item = line.strip().split("\t")
215
+ image_name, landmarks_5pts, landmarks_gt, scale, center_w, center_h = item[:6]
216
+ # image & keypoints alignment
217
+ image_name = image_name.replace('\\', '/')
218
+ image_name = image_name.replace('//msr-facestore/Workspace/MSRA_EP_Allergan/users/yanghuan/training_data/wflw/rawImages/', '')
219
+ image_name = image_name.replace('./rawImages/', '')
220
+ image_path = os.path.join(config.image_dir, image_name)
221
+ landmarks_gt = np.array(list(map(float, landmarks_gt.split(","))), dtype=np.float32).reshape(-1, 2)
222
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
223
+
224
+ image = cv2.imread(image_path)
225
+ landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
226
+
227
+ # NME
228
+ if mode == "nme":
229
+ nme = NME(landmarks_gt, landmarks_pv)
230
+ nme_sum += nme
231
+ # print("Current NME(%d): %f" % (k + 1, (nme_sum / (k + 1))))
232
+ else:
233
+ pass
234
+
235
+ if mode == "nme":
236
+ print("Final NME: %f" % (100*nme_sum / (k + 1)))
237
+ else:
238
+ pass
239
+
240
+
241
+ if __name__ == "__main__":
242
+ parser = argparse.ArgumentParser(description="Evaluation script")
243
+ parser.add_argument("--config_name", type=str, default="alignment", help="set configure file name")
244
+ parser.add_argument("--model_path", type=str, default="./train.pkl", help="the path of model")
245
+ parser.add_argument("--data_definition", type=str, default='WFLW', help="COFW/300W/WFLW")
246
+ parser.add_argument("--metadata_path", type=str, default="", help="the path of metadata")
247
+ parser.add_argument("--image_dir", type=str, default="", help="the path of image")
248
+ parser.add_argument("--device_ids", type=str, default="0", help="set device ids, -1 means use cpu device, >= 0 means use gpu device")
249
+ parser.add_argument("--mode", type=str, default="nme", help="set the evaluate mode: nme")
250
+ args = parser.parse_args()
251
+
252
+ device_ids = list(map(int, args.device_ids.split(",")))
253
+ evaluate(
254
+ args,
255
+ model_path=args.model_path,
256
+ metadata_path=args.metadata_path,
257
+ device_ids=device_ids,
258
+ mode=args.mode)
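A tiny sanity check for the `NME` helper above (synthetic 68-point data, assuming `NME` from this file is in scope): a perfect prediction gives 0, and shifting every point by one pixel gives exactly one over the inter-ocular distance.

```python
import numpy as np

rng = np.random.default_rng(0)
gt = rng.uniform(0, 255, size=(68, 2)).astype(np.float32)

print(NME(gt, gt))                                   # 0.0 for a perfect prediction

pred = gt + np.array([1.0, 0.0], dtype=np.float32)   # shift every point by 1 px in x
eye_span = np.linalg.norm(gt[36] - gt[45])           # outer-eye-corner distance for 68 points
print(np.isclose(NME(gt, pred), 1.0 / eye_span))     # True
```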
external/landmark_detection/infer_folder.py ADDED
@@ -0,0 +1,253 @@
1
+ import cv2
2
+ import math
3
+ import copy
4
+ import numpy as np
5
+ import argparse
6
+ import torch
7
+ import json
8
+
9
+ # private package
10
+ from lib import utility
11
+ from FaceBoxesV2.faceboxes_detector import *
12
+
13
+ class GetCropMatrix():
14
+ """
15
+ from_shape -> transform_matrix
16
+ """
17
+
18
+ def __init__(self, image_size, target_face_scale, align_corners=False):
19
+ self.image_size = image_size
20
+ self.target_face_scale = target_face_scale
21
+ self.align_corners = align_corners
22
+
23
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
24
+ cosv = math.cos(angle)
25
+ sinv = math.sin(angle)
26
+
27
+ fx, fy = from_center
28
+ tx, ty = to_center
29
+
30
+ acos = scale * cosv
31
+ asin = scale * sinv
32
+
33
+ a0 = acos
34
+ a1 = -asin
35
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
36
+
37
+ b0 = asin
38
+ b1 = acos
39
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
40
+
41
+ rot_scale_m = np.array([
42
+ [a0, a1, a2],
43
+ [b0, b1, b2],
44
+ [0.0, 0.0, 1.0]
45
+ ], np.float32)
46
+ return rot_scale_m
47
+
48
+ def process(self, scale, center_w, center_h):
49
+ if self.align_corners:
50
+ to_w, to_h = self.image_size - 1, self.image_size - 1
51
+ else:
52
+ to_w, to_h = self.image_size, self.image_size
53
+
54
+ rot_mu = 0
55
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
56
+ shift_xy_mu = (0, 0)
57
+ matrix = self._compose_rotate_and_scale(
58
+ rot_mu, scale_mu, shift_xy_mu,
59
+ from_center=[center_w, center_h],
60
+ to_center=[to_w / 2.0, to_h / 2.0])
61
+ return matrix
62
+
63
+
64
+ class TransformPerspective():
65
+ """
66
+ image, matrix3x3 -> transformed_image
67
+ """
68
+
69
+ def __init__(self, image_size):
70
+ self.image_size = image_size
71
+
72
+ def process(self, image, matrix):
73
+ return cv2.warpPerspective(
74
+ image, matrix, dsize=(self.image_size, self.image_size),
75
+ flags=cv2.INTER_LINEAR, borderValue=0)
76
+
77
+
78
+ class TransformPoints2D():
79
+ """
80
+ points (nx2), matrix (3x3) -> points (nx2)
81
+ """
82
+
83
+ def process(self, srcPoints, matrix):
84
+ # nx3
85
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
86
+ desPoints = desPoints @ np.transpose(matrix) # nx3
87
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
88
+ return desPoints.astype(srcPoints.dtype)
89
+
90
+ class Alignment:
91
+ def __init__(self, args, model_path, dl_framework, device_ids):
92
+ self.input_size = 256
93
+ self.target_face_scale = 1.0
94
+ self.dl_framework = dl_framework
95
+
96
+ # model
97
+ if self.dl_framework == "pytorch":
98
+ # conf
99
+ self.config = utility.get_config(args)
100
+ self.config.device_id = device_ids[0]
101
+ # set environment
102
+ utility.set_environment(self.config)
103
+ # self.config.init_instance()
104
+ # if self.config.logger is not None:
105
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
106
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
107
+
108
+ net = utility.get_net(self.config)
109
+ if device_ids == [-1]:
110
+ checkpoint = torch.load(model_path, map_location="cpu")
111
+ else:
112
+ checkpoint = torch.load(model_path)
113
+ net.load_state_dict(checkpoint["net"])
114
+
115
+ if self.config.device_id == -1:
116
+ net = net.cpu()
117
+ else:
118
+ net = net.to(self.config.device_id)
119
+
120
+ net.eval()
121
+ self.alignment = net
122
+ else:
123
+ assert False
124
+
125
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
126
+ align_corners=True)
127
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
128
+ self.transformPoints2D = TransformPoints2D()
129
+
130
+ def norm_points(self, points, align_corners=False):
131
+ if align_corners:
132
+ # [0, SIZE-1] -> [-1, +1]
133
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
134
+ else:
135
+ # [-0.5, SIZE-0.5] -> [-1, +1]
136
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
137
+
138
+ def denorm_points(self, points, align_corners=False):
139
+ if align_corners:
140
+ # [-1, +1] -> [0, SIZE-1]
141
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
142
+ else:
143
+ # [-1, +1] -> [-0.5, SIZE-0.5]
144
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
145
+
146
+ def preprocess(self, image, scale, center_w, center_h):
147
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
148
+ input_tensor = self.transformPerspective.process(image, matrix)
149
+ input_tensor = input_tensor[np.newaxis, :]
150
+
151
+ input_tensor = torch.from_numpy(input_tensor)
152
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
153
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
154
+
155
+ if self.config.device_id == -1:
156
+ input_tensor = input_tensor.cpu()
157
+ else:
158
+ input_tensor = input_tensor.to(self.config.device_id)
159
+
160
+ return input_tensor, matrix
161
+
162
+ def postprocess(self, srcPoints, coeff):
163
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
164
+ # matrix^(-1) * src = dst
165
+ # src = matrix * dst
166
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
167
+ for i in range(srcPoints.shape[0]):
168
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
169
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
170
+ return dstPoints
171
+
172
+ def analyze(self, image, scale, center_w, center_h):
173
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
174
+
175
+ if self.dl_framework == "pytorch":
176
+ with torch.no_grad():
177
+ output = self.alignment(input_tensor)
178
+ landmarks = output[-1][0]
179
+ else:
180
+ assert False
181
+
182
+ landmarks = self.denorm_points(landmarks)
183
+ landmarks = landmarks.data.cpu().numpy()[0]
184
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
185
+
186
+ return landmarks
187
+
188
+ if __name__ == '__main__':
189
+ parser = argparse.ArgumentParser(description="inference script")
190
+ parser.add_argument('--folder_path', type=str, help='Path to image folder')
191
+ args = parser.parse_args()
192
+
193
+ # args.folder_path = '/media/gyalex/Data/flame/ph_test/head_images/flame/image'
194
+
195
+ current_path = os.getcwd()
196
+
197
+ use_gpu = True
198
+ ########### face detection ############
199
+ if use_gpu:
200
+ device = torch.device("cuda:0")
201
+ else:
202
+ device = torch.device("cpu")
203
+
204
+ current_path = os.getcwd()
205
+ det_model_path = os.path.join(current_path, 'preprocess', 'submodules', 'Landmark_detection', 'FaceBoxesV2/weights/FaceBoxesV2.pth')
206
+ detector = FaceBoxesDetector('FaceBoxes', det_model_path, use_gpu, device)
207
+
208
+ ########### facial alignment ############
209
+ model_path = os.path.join(current_path, 'preprocess', 'submodules', 'Landmark_detection', 'weights/68_keypoints_model.pkl')
210
+
211
+ if use_gpu:
212
+ device_ids = [0]
213
+ else:
214
+ device_ids = [-1]
215
+
216
+ args.config_name = 'alignment'
217
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
218
+
219
+ img_path_list = os.listdir(args.folder_path)
220
+ kpts_code = dict()
221
+
222
+ ########### inference ############
223
+ for file_name in img_path_list:
224
+ abs_path = os.path.join(args.folder_path, file_name)
225
+
226
+ image = cv2.imread(abs_path)
227
+ image_draw = copy.deepcopy(image)
228
+
229
+ detections, _ = detector.detect(image, 0.6, 1)
230
+ for idx in range(len(detections)):
231
+ x1_ori = detections[idx][2]
232
+ y1_ori = detections[idx][3]
233
+ x2_ori = x1_ori + detections[idx][4]
234
+ y2_ori = y1_ori + detections[idx][5]
235
+
236
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
237
+ center_w = (x1_ori + x2_ori) / 2
238
+ center_h = (y1_ori + y2_ori) / 2
239
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
240
+
241
+ landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
242
+ landmarks_pv_list = landmarks_pv.tolist()
243
+
244
+ for num in range(landmarks_pv.shape[0]):
245
+ cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
246
+ 2, (0, 255, 0), -1)
247
+
248
+ kpts_code[file_name] = landmarks_pv_list
249
+ save_path = args.folder_path[:-5] + 'landmark'
250
+ cv2.imwrite(os.path.join(save_path, file_name), image_draw)
251
+
252
+ path = args.folder_path[:-5]
253
+ json.dump(kpts_code, open(os.path.join(path, 'keypoint.json'), 'w'))
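The detection loop above converts each FaceBoxes box into the `(scale, center_w, center_h)` triple expected by `Alignment.analyze`. A worked example with a hypothetical box (x=100, y=120, w=90, h=108):

```python
# Hypothetical FaceBoxes detection [label, score, x, y, w, h] -> crop parameters,
# following the same arithmetic as the loop above.
x1, y1, w, h = 100.0, 120.0, 90.0, 108.0
x2, y2 = x1 + w, y1 + h

scale = max(x2 - x1, y2 - y1) / 180      # 0.6
center_w = (x1 + x2) / 2                 # 145.0
center_h = (y1 + y2) / 2                 # 174.0
print(scale, center_w, center_h)
```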
external/landmark_detection/infer_image.py ADDED
@@ -0,0 +1,251 @@
1
+ import cv2
2
+ import math
3
+ import copy
4
+ import numpy as np
5
+ import argparse
6
+ import torch
7
+
8
+ # private package
9
+ from external.landmark_detection.lib import utility
10
+ from external.landmark_detection.FaceBoxesV2.faceboxes_detector import *
11
+
12
+ class GetCropMatrix():
13
+ """
14
+ from_shape -> transform_matrix
15
+ """
16
+
17
+ def __init__(self, image_size, target_face_scale, align_corners=False):
18
+ self.image_size = image_size
19
+ self.target_face_scale = target_face_scale
20
+ self.align_corners = align_corners
21
+
22
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
23
+ cosv = math.cos(angle)
24
+ sinv = math.sin(angle)
25
+
26
+ fx, fy = from_center
27
+ tx, ty = to_center
28
+
29
+ acos = scale * cosv
30
+ asin = scale * sinv
31
+
32
+ a0 = acos
33
+ a1 = -asin
34
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
35
+
36
+ b0 = asin
37
+ b1 = acos
38
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
39
+
40
+ rot_scale_m = np.array([
41
+ [a0, a1, a2],
42
+ [b0, b1, b2],
43
+ [0.0, 0.0, 1.0]
44
+ ], np.float32)
45
+ return rot_scale_m
46
+
47
+ def process(self, scale, center_w, center_h):
48
+ if self.align_corners:
49
+ to_w, to_h = self.image_size - 1, self.image_size - 1
50
+ else:
51
+ to_w, to_h = self.image_size, self.image_size
52
+
53
+ rot_mu = 0
54
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
55
+ shift_xy_mu = (0, 0)
56
+ matrix = self._compose_rotate_and_scale(
57
+ rot_mu, scale_mu, shift_xy_mu,
58
+ from_center=[center_w, center_h],
59
+ to_center=[to_w / 2.0, to_h / 2.0])
60
+ return matrix
61
+
62
+
63
+ class TransformPerspective():
64
+ """
65
+ image, matrix3x3 -> transformed_image
66
+ """
67
+
68
+ def __init__(self, image_size):
69
+ self.image_size = image_size
70
+
71
+ def process(self, image, matrix):
72
+ return cv2.warpPerspective(
73
+ image, matrix, dsize=(self.image_size, self.image_size),
74
+ flags=cv2.INTER_LINEAR, borderValue=0)
75
+
76
+
77
+ class TransformPoints2D():
78
+ """
79
+ points (nx2), matrix (3x3) -> points (nx2)
80
+ """
81
+
82
+ def process(self, srcPoints, matrix):
83
+ # nx3
84
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
85
+ desPoints = desPoints @ np.transpose(matrix) # nx3
86
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
87
+ return desPoints.astype(srcPoints.dtype)
88
+
89
+ class Alignment:
90
+ def __init__(self, args, model_path, dl_framework, device_ids):
91
+ self.input_size = 256
92
+ self.target_face_scale = 1.0
93
+ self.dl_framework = dl_framework
94
+
95
+ # model
96
+ if self.dl_framework == "pytorch":
97
+ # conf
98
+ self.config = utility.get_config(args)
99
+ self.config.device_id = device_ids[0]
100
+ # set environment
101
+ # utility.set_environment(self.config)
102
+ # self.config.init_instance()
103
+ # if self.config.logger is not None:
104
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
105
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
106
+
107
+ net = utility.get_net(self.config)
108
+ if device_ids == [-1]:
109
+ checkpoint = torch.load(model_path, map_location="cpu")
110
+ else:
111
+ checkpoint = torch.load(model_path)
112
+ net.load_state_dict(checkpoint["net"])
113
+
114
+ if self.config.device_id == -1:
115
+ net = net.cpu()
116
+ else:
117
+ net = net.to(self.config.device_id)
118
+
119
+ net.eval()
120
+ self.alignment = net
121
+ else:
122
+ assert False
123
+
124
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
125
+ align_corners=True)
126
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
127
+ self.transformPoints2D = TransformPoints2D()
128
+
129
+ def norm_points(self, points, align_corners=False):
130
+ if align_corners:
131
+ # [0, SIZE-1] -> [-1, +1]
132
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
133
+ else:
134
+ # [-0.5, SIZE-0.5] -> [-1, +1]
135
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
136
+
137
+ def denorm_points(self, points, align_corners=False):
138
+ if align_corners:
139
+ # [-1, +1] -> [0, SIZE-1]
140
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
141
+ else:
142
+ # [-1, +1] -> [-0.5, SIZE-0.5]
143
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
144
+
145
+ def preprocess(self, image, scale, center_w, center_h):
146
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
147
+ input_tensor = self.transformPerspective.process(image, matrix)
148
+ input_tensor = input_tensor[np.newaxis, :]
149
+
150
+ input_tensor = torch.from_numpy(input_tensor)
151
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
152
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
153
+
154
+ if self.config.device_id == -1:
155
+ input_tensor = input_tensor.cpu()
156
+ else:
157
+ input_tensor = input_tensor.to(self.config.device_id)
158
+
159
+ return input_tensor, matrix
160
+
161
+ def postprocess(self, srcPoints, coeff):
162
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
163
+ # matrix^(-1) * src = dst
164
+ # src = matrix * dst
165
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
166
+ for i in range(srcPoints.shape[0]):
167
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
168
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
169
+ return dstPoints
170
+
171
+ def analyze(self, image, scale, center_w, center_h):
172
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
173
+
174
+ if self.dl_framework == "pytorch":
175
+ with torch.no_grad():
176
+ output = self.alignment(input_tensor)
177
+ landmarks = output[-1][0]
178
+ else:
179
+ assert False
180
+
181
+ landmarks = self.denorm_points(landmarks)
182
+ landmarks = landmarks.data.cpu().numpy()[0]
183
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
184
+
185
+ return landmarks
186
+
187
+ # parser = argparse.ArgumentParser(description="Evaluation script")
188
+ # args = parser.parse_args()
189
+ # image_path = './rgb.png'
190
+ # image = cv2.imread(image_path)
191
+ #
192
+ # use_gpu = False
193
+ # ########### face detection ############
194
+ # if use_gpu:
195
+ # device = torch.device("cuda:0")
196
+ # else:
197
+ # device = torch.device("cpu")
198
+ #
199
+ # detector = FaceBoxesDetector('FaceBoxes', 'FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)
200
+ #
201
+ # ########### facial alignment ############
202
+ # model_path = './weights/68_keypoints_model.pkl'
203
+ #
204
+ # if use_gpu:
205
+ # device_ids = [0]
206
+ # else:
207
+ # device_ids = [-1]
208
+ #
209
+ # args.config_name = 'alignment'
210
+ # alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
211
+ # image_draw = copy.deepcopy(image)
212
+ #
213
+ # ########### inference ############
214
+ # ldk_list = []
215
+ #
216
+ # detections, _ = detector.detect(image, 0.9, 1)
217
+ # for idx in range(len(detections)):
218
+ # x1_ori = detections[idx][2]
219
+ # y1_ori = detections[idx][3]
220
+ # x2_ori = x1_ori + detections[idx][4]
221
+ # y2_ori = y1_ori + detections[idx][5]
222
+ #
223
+ # scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
224
+ # center_w = (x1_ori + x2_ori) / 2
225
+ # center_h = (y1_ori + y2_ori) / 2
226
+ # scale, center_w, center_h = float(scale), float(center_w), float(center_h)
227
+ #
228
+ # landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
229
+ #
230
+ # for num in range(landmarks_pv.shape[0]):
231
+ # cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
232
+ # 2, (0, 255, 0), -1)
233
+ #
234
+ # ldk_list.append([round(landmarks_pv[num][0]), round(landmarks_pv[num][1])])
235
+ #
236
+ # cv2.imshow("win", image_draw)
237
+ #
238
+ # # ldk_img = cv2.imread('/home/gyalex/Desktop/image_landmark_149/all.jpg')
239
+ # # cv2.imshow("win1", ldk_img)
240
+ #
241
+ # cv2.waitKey(0)
242
+ #
243
+ # with open('./cord.txt', 'w') as f:
244
+ # for num in range(len(ldk_list)):
245
+ # s = str(ldk_list[num][0]) + ' ' + str(ldk_list[num][1]) + '\n'
246
+ # f.write(s)
247
+ #
248
+ # f.close()
249
+
250
+
251
+
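As a round-trip check of the crop/uncrop pair used in `Alignment.analyze` above (assuming `GetCropMatrix` from this file is in scope): `matrix` maps original-image points into the 256x256 crop, and applying its inverse, as `postprocess` does, maps predictions back to the original image.

```python
import numpy as np

def apply_affine(points, coeff):
    """Same 2x3 affine expansion as Alignment.postprocess above."""
    out = np.zeros_like(points)
    out[:, 0] = coeff[0, 0] * points[:, 0] + coeff[0, 1] * points[:, 1] + coeff[0, 2]
    out[:, 1] = coeff[1, 0] * points[:, 0] + coeff[1, 1] * points[:, 1] + coeff[1, 2]
    return out

crop = GetCropMatrix(image_size=256, target_face_scale=1.0, align_corners=True)
matrix = crop.process(scale=1.0, center_w=200.0, center_h=150.0)

pts_orig = np.array([[180.0, 140.0], [220.0, 160.0]])
pts_crop = apply_affine(pts_orig, matrix)                  # original image -> crop space
pts_back = apply_affine(pts_crop, np.linalg.inv(matrix))   # crop space -> original image
print(np.allclose(pts_back, pts_orig, atol=1e-3))          # True
```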
external/landmark_detection/infer_video.py ADDED
@@ -0,0 +1,287 @@
1
+ import cv2
2
+ import math
3
+ import copy
4
+ import numpy as np
5
+ import argparse
6
+ import torch
7
+ import json
+ import os  # used below for os.getcwd() and os.path.join
8
+
9
+ # private package
10
+ from lib import utility
11
+ from FaceBoxesV2.faceboxes_detector import *
12
+
13
+ class GetCropMatrix():
14
+ """
15
+ from_shape -> transform_matrix
16
+ """
17
+
18
+ def __init__(self, image_size, target_face_scale, align_corners=False):
19
+ self.image_size = image_size
20
+ self.target_face_scale = target_face_scale
21
+ self.align_corners = align_corners
22
+
23
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
24
+ cosv = math.cos(angle)
25
+ sinv = math.sin(angle)
26
+
27
+ fx, fy = from_center
28
+ tx, ty = to_center
29
+
30
+ acos = scale * cosv
31
+ asin = scale * sinv
32
+
33
+ a0 = acos
34
+ a1 = -asin
35
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
36
+
37
+ b0 = asin
38
+ b1 = acos
39
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
40
+
41
+ rot_scale_m = np.array([
42
+ [a0, a1, a2],
43
+ [b0, b1, b2],
44
+ [0.0, 0.0, 1.0]
45
+ ], np.float32)
46
+ return rot_scale_m
47
+
48
+ def process(self, scale, center_w, center_h):
49
+ if self.align_corners:
50
+ to_w, to_h = self.image_size - 1, self.image_size - 1
51
+ else:
52
+ to_w, to_h = self.image_size, self.image_size
53
+
54
+ rot_mu = 0
55
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
56
+ shift_xy_mu = (0, 0)
57
+ matrix = self._compose_rotate_and_scale(
58
+ rot_mu, scale_mu, shift_xy_mu,
59
+ from_center=[center_w, center_h],
60
+ to_center=[to_w / 2.0, to_h / 2.0])
61
+ return matrix
62
+
63
+
64
+ class TransformPerspective():
65
+ """
66
+ image, matrix3x3 -> transformed_image
67
+ """
68
+
69
+ def __init__(self, image_size):
70
+ self.image_size = image_size
71
+
72
+ def process(self, image, matrix):
73
+ return cv2.warpPerspective(
74
+ image, matrix, dsize=(self.image_size, self.image_size),
75
+ flags=cv2.INTER_LINEAR, borderValue=0)
76
+
77
+
78
+ class TransformPoints2D():
79
+ """
80
+ points (nx2), matrix (3x3) -> points (nx2)
81
+ """
82
+
83
+ def process(self, srcPoints, matrix):
84
+ # nx3
85
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
86
+ desPoints = desPoints @ np.transpose(matrix) # nx3
87
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
88
+ return desPoints.astype(srcPoints.dtype)
89
+
90
+ class Alignment:
91
+ def __init__(self, args, model_path, dl_framework, device_ids):
92
+ self.input_size = 256
93
+ self.target_face_scale = 1.0
94
+ self.dl_framework = dl_framework
95
+
96
+ # model
97
+ if self.dl_framework == "pytorch":
98
+ # conf
99
+ self.config = utility.get_config(args)
100
+ self.config.device_id = device_ids[0]
101
+ # set environment
102
+ utility.set_environment(self.config)
103
+ # self.config.init_instance()
104
+ # if self.config.logger is not None:
105
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
106
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
107
+
108
+ net = utility.get_net(self.config)
109
+ if device_ids == [-1]:
110
+ checkpoint = torch.load(model_path, map_location="cpu")
111
+ else:
112
+ checkpoint = torch.load(model_path)
113
+ net.load_state_dict(checkpoint["net"])
114
+
115
+ if self.config.device_id == -1:
116
+ net = net.cpu()
117
+ else:
118
+ net = net.to(self.config.device_id)
119
+
120
+ net.eval()
121
+ self.alignment = net
122
+ else:
123
+ assert False
124
+
125
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
126
+ align_corners=True)
127
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
128
+ self.transformPoints2D = TransformPoints2D()
129
+
130
+ def norm_points(self, points, align_corners=False):
131
+ if align_corners:
132
+ # [0, SIZE-1] -> [-1, +1]
133
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
134
+ else:
135
+ # [-0.5, SIZE-0.5] -> [-1, +1]
136
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
137
+
138
+ def denorm_points(self, points, align_corners=False):
139
+ if align_corners:
140
+ # [-1, +1] -> [0, SIZE-1]
141
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
142
+ else:
143
+ # [-1, +1] -> [-0.5, SIZE-0.5]
144
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
145
+
146
+ def preprocess(self, image, scale, center_w, center_h):
147
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
148
+ input_tensor = self.transformPerspective.process(image, matrix)
149
+ input_tensor = input_tensor[np.newaxis, :]
150
+
151
+ input_tensor = torch.from_numpy(input_tensor)
152
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
153
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
154
+
155
+ if self.config.device_id == -1:
156
+ input_tensor = input_tensor.cpu()
157
+ else:
158
+ input_tensor = input_tensor.to(self.config.device_id)
159
+
160
+ return input_tensor, matrix
161
+
162
+ def postprocess(self, srcPoints, coeff):
163
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
164
+ # matrix^(-1) * src = dst
165
+ # src = matrix * dst
166
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
167
+ for i in range(srcPoints.shape[0]):
168
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
169
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
170
+ return dstPoints
171
+
172
+ def analyze(self, image, scale, center_w, center_h):
173
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
174
+
175
+ if self.dl_framework == "pytorch":
176
+ with torch.no_grad():
177
+ output = self.alignment(input_tensor)
178
+ landmarks = output[-1][0]
179
+ else:
180
+ assert False
181
+
182
+ landmarks = self.denorm_points(landmarks)
183
+ landmarks = landmarks.data.cpu().numpy()[0]
184
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
185
+
186
+ return landmarks
187
+
188
+ if __name__ == '__main__':
189
+ parser = argparse.ArgumentParser(description="inference script")
190
+ parser.add_argument('--video_path', type=str, help='Path to videos',default='/media/yuanzhen/HH/DATASET/VFTH/TESTVIDEO/Clip+7CzHzeeVRlE+P0+C0+F101007-101139.mp4')
191
+ args = parser.parse_args()
192
+
193
+ # args.video_path = '/media/gyalex/Data/flame/ph_test/test.mp4'
194
+
195
+ current_path = os.getcwd()
196
+
197
+ use_gpu = True
198
+ ########### face detection ############
199
+ if use_gpu:
200
+ device = torch.device("cuda:0")
201
+ else:
202
+ device = torch.device("cpu")
203
+
204
+ current_path = os.getcwd()
205
+ det_model_path = '/home/yuanzhen/code/landmark_detection/FaceBoxesV2/weights/FaceBoxesV2.pth'
206
+ detector = FaceBoxesDetector('FaceBoxes', det_model_path, use_gpu, device)
207
+
208
+ ########### facial alignment ############
209
+ model_path = '/home/yuanzhen/code/landmark_detection/weights/68_keypoints_model.pkl'
210
+
211
+ if use_gpu:
212
+ device_ids = [0]
213
+ else:
214
+ device_ids = [-1]
215
+
216
+ args.config_name = 'alignment'
217
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
218
+
219
+ video_file = args.video_path
220
+ cap = cv2.VideoCapture(video_file)
221
+ frame_width = int(cap.get(3))
222
+ frame_height = int(cap.get(4))
223
+
224
+ # out_video_file = './output_video.mp4'
225
+ # fps = 30
226
+ # size = (frame_width, frame_height)
227
+ # out = cv2.VideoWriter(out_video_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
228
+
229
+ count = 0
230
+ kpts_code = dict()
231
+
232
+ keypoint_data_path = args.video_path.replace('.mp4','.json')
233
+ with open(keypoint_data_path,'r') as f:
234
+ keypoint_data = json.load(f)
235
+
236
+ ########### inference ############
237
+ path = video_file[:-4]
238
+ while(cap.isOpened()):
239
+ ret, image = cap.read()
240
+
241
+ if ret:
242
+ detections, _ = detector.detect(image, 0.8, 1)
243
+ image_draw = copy.deepcopy(image)
244
+
245
+ cv2.imwrite(os.path.join(path, 'image', str(count+1)+'.png'), image_draw)
246
+
247
+ for idx in range(len(detections)):
248
+ x1_ori = detections[idx][2]
249
+ y1_ori = detections[idx][3]
250
+ x2_ori = x1_ori + detections[idx][4]
251
+ y2_ori = y1_ori + detections[idx][5]
252
+
253
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
254
+ center_w = (x1_ori + x2_ori) / 2
255
+ center_h = (y1_ori + y2_ori) / 2
256
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
257
+
258
+ # landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
259
+ landmarks_pv = np.array(keypoint_data[str(count+1)+'.png'])
260
+
261
+ landmarks_pv_list = landmarks_pv.tolist()
262
+
263
+ for num in range(landmarks_pv.shape[0]):
264
+ cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
265
+ 2, (0, 255, 0), -1)
266
+ cv2.putText(image_draw, str(num),
267
+ (round(landmarks_pv[num][0]) + 5, round(landmarks_pv[num][1]) + 5),  # text position
268
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)
269
+
270
+ kpts_code[str(count+1)+'.png'] = landmarks_pv_list
271
+ cv2.imwrite(os.path.join(path, 'landmark', str(count+1)+'.png'), image_draw)
272
+ else:
273
+ break
274
+
275
+ count += 1
276
+
277
+ cap.release()
278
+ # out.release()
279
+ # cv2.destroyAllWindows()
280
+
281
+ path = video_file[:-4]
282
+ json.dump(kpts_code, open(os.path.join(path, 'keypoint.json'), 'w'))
283
+
284
+ print(path)
285
+
286
+
287
+
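Usage note: in this script the `alignment.analyze` call is commented out and per-frame landmarks are read from a `<video>.json` file next to the video; raw frames, drawn frames and `keypoint.json` are written under the video's basename into `image/` and `landmark/` folders that the script assumes already exist. A small worked example of the box -> (scale, center) conversion used in the loop above (the box values are made up):

# a hypothetical 180x200 detection box with its top-left corner at (100, 50)
x1, y1, w, h = 100.0, 50.0, 180.0, 200.0
scale = max(w, h) / 180      # 200 / 180 ≈ 1.11
center_w = x1 + w / 2        # 190.0
center_h = y1 + h / 2        # 150.0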
external/landmark_detection/lib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .dataset import get_encoder, get_decoder
2
+ from .dataset import AlignmentDataset, Augmentation
3
+ from .backbone import StackedHGNetV1
4
+ from .metric import NME, Accuracy
5
+ from .utils import time_print, time_string, time_for_file, time_string_short
6
+ from .utils import convert_secs2time, convert_size2str
7
+
8
+ from .utility import get_dataloader, get_config, get_net, get_criterions
9
+ from .utility import get_optimizer, get_scheduler
external/landmark_detection/lib/backbone/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .stackedHGNetV1 import StackedHGNetV1
2
+
3
+ __all__ = [
4
+ "StackedHGNetV1",
5
+ ]
external/landmark_detection/lib/backbone/core/coord_conv.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class AddCoordsTh(nn.Module):
6
+ def __init__(self, x_dim, y_dim, with_r=False, with_boundary=False):
7
+ super(AddCoordsTh, self).__init__()
8
+ self.x_dim = x_dim
9
+ self.y_dim = y_dim
10
+ self.with_r = with_r
11
+ self.with_boundary = with_boundary
12
+
13
+ def forward(self, input_tensor, heatmap=None):
14
+ """
15
+ input_tensor: (batch, c, x_dim, y_dim)
16
+ """
17
+ batch_size_tensor = input_tensor.shape[0]
18
+
19
+ xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(input_tensor)
20
+ xx_ones = xx_ones.unsqueeze(-1)
21
+
22
+ xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor)
23
+ xx_range = xx_range.unsqueeze(1)
24
+
25
+ xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
26
+ xx_channel = xx_channel.unsqueeze(-1)
27
+
28
+ yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(input_tensor)
29
+ yy_ones = yy_ones.unsqueeze(1)
30
+
31
+ yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor)
32
+ yy_range = yy_range.unsqueeze(-1)
33
+
34
+ yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
35
+ yy_channel = yy_channel.unsqueeze(-1)
36
+
37
+ xx_channel = xx_channel.permute(0, 3, 2, 1)
38
+ yy_channel = yy_channel.permute(0, 3, 2, 1)
39
+
40
+ xx_channel = xx_channel / (self.x_dim - 1)
41
+ yy_channel = yy_channel / (self.y_dim - 1)
42
+
43
+ xx_channel = xx_channel * 2 - 1
44
+ yy_channel = yy_channel * 2 - 1
45
+
46
+ xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
47
+ yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
48
+
49
+ if self.with_boundary and heatmap is not None:
50
+ boundary_channel = torch.clamp(heatmap[:, -1:, :, :],
51
+ 0.0, 1.0)
52
+
53
+ zero_tensor = torch.zeros_like(xx_channel).to(xx_channel)
54
+ xx_boundary_channel = torch.where(boundary_channel>0.05,
55
+ xx_channel, zero_tensor)
56
+ yy_boundary_channel = torch.where(boundary_channel>0.05,
57
+ yy_channel, zero_tensor)
58
+ ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
59
+
60
+
61
+ if self.with_r:
62
+ rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
63
+ rr = rr / torch.max(rr)
64
+ ret = torch.cat([ret, rr], dim=1)
65
+
66
+ if self.with_boundary and heatmap is not None:
67
+ ret = torch.cat([ret, xx_boundary_channel,
68
+ yy_boundary_channel], dim=1)
69
+ return ret
70
+
71
+
72
+ class CoordConvTh(nn.Module):
73
+ """CoordConv layer as in the paper."""
74
+ def __init__(self, x_dim, y_dim, with_r, with_boundary,
75
+ in_channels, out_channels, first_one=False, relu=False, bn=False, *args, **kwargs):
76
+ super(CoordConvTh, self).__init__()
77
+ self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
78
+ with_boundary=with_boundary)
79
+ in_channels += 2
80
+ if with_r:
81
+ in_channels += 1
82
+ if with_boundary and not first_one:
83
+ in_channels += 2
84
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, *args, **kwargs)
85
+ self.relu = nn.ReLU() if relu else None
86
+ self.bn = nn.BatchNorm2d(out_channels) if bn else None
87
+
88
+ self.with_boundary = with_boundary
89
+ self.first_one = first_one
90
+
91
+
92
+ def forward(self, input_tensor, heatmap=None):
93
+ assert (self.with_boundary and not self.first_one) == (heatmap is not None)
94
+ ret = self.addcoords(input_tensor, heatmap)
95
+ ret = self.conv(ret)
96
+ if self.bn is not None:
97
+ ret = self.bn(ret)
98
+ if self.relu is not None:
99
+ ret = self.relu(ret)
100
+
101
+ return ret
102
+
103
+
104
+ '''
105
+ An alternative implementation for PyTorch with auto-infering the x-y dimensions.
106
+ '''
107
+ class AddCoords(nn.Module):
108
+
109
+ def __init__(self, with_r=False):
110
+ super().__init__()
111
+ self.with_r = with_r
112
+
113
+ def forward(self, input_tensor):
114
+ """
115
+ Args:
116
+ input_tensor: shape(batch, channel, x_dim, y_dim)
117
+ """
118
+ batch_size, _, x_dim, y_dim = input_tensor.size()
119
+
120
+ xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1).to(input_tensor)
121
+ yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2).to(input_tensor)
122
+
123
+ xx_channel = xx_channel / (x_dim - 1)
124
+ yy_channel = yy_channel / (y_dim - 1)
125
+
126
+ xx_channel = xx_channel * 2 - 1
127
+ yy_channel = yy_channel * 2 - 1
128
+
129
+ xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
130
+ yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
131
+
132
+ ret = torch.cat([
133
+ input_tensor,
134
+ xx_channel.type_as(input_tensor),
135
+ yy_channel.type_as(input_tensor)], dim=1)
136
+
137
+ if self.with_r:
138
+ rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
139
+ ret = torch.cat([ret, rr], dim=1)
140
+
141
+ return ret
142
+
143
+
144
+ class CoordConv(nn.Module):
145
+
146
+ def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
147
+ super().__init__()
148
+ self.addcoords = AddCoords(with_r=with_r)
149
+ in_channels += 2
150
+ if with_r:
151
+ in_channels += 1
152
+ self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
153
+
154
+ def forward(self, x):
155
+ ret = self.addcoords(x)
156
+ ret = self.conv(ret)
157
+ return ret
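A quick shape check for the auto-inferring `CoordConv` variant above (a sketch; it assumes the repository root is on `PYTHONPATH`):

import torch
from external.landmark_detection.lib.backbone.core.coord_conv import CoordConv

# AddCoords appends normalized x/y channels (plus a radius channel when with_r=True),
# so the wrapped nn.Conv2d sees 3 + 2 + 1 = 6 input channels here
conv = CoordConv(in_channels=3, out_channels=8, with_r=True, kernel_size=3, padding=1)
x = torch.randn(2, 3, 64, 64)
print(conv(x).shape)  # torch.Size([2, 8, 64, 64])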
external/landmark_detection/lib/backbone/stackedHGNetV1.py ADDED
@@ -0,0 +1,307 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .core.coord_conv import CoordConvTh
8
+ from external.landmark_detection.lib.dataset import get_decoder
9
+
10
+
11
+
12
+ class Activation(nn.Module):
13
+ def __init__(self, kind: str = 'relu', channel=None):
14
+ super().__init__()
15
+ self.kind = kind
16
+
17
+ if '+' in kind:
18
+ norm_str, act_str = kind.split('+')
19
+ else:
20
+ norm_str, act_str = 'none', kind
21
+
22
+ self.norm_fn = {
23
+ 'in': F.instance_norm,
24
+ 'bn': nn.BatchNorm2d(channel),
25
+ 'bn_noaffine': nn.BatchNorm2d(channel, affine=False, track_running_stats=True),
26
+ 'none': None
27
+ }[norm_str]
28
+
29
+ self.act_fn = {
30
+ 'relu': F.relu,
31
+ 'softplus': nn.Softplus(),
32
+ 'exp': torch.exp,
33
+ 'sigmoid': torch.sigmoid,
34
+ 'tanh': torch.tanh,
35
+ 'none': None
36
+ }[act_str]
37
+
38
+ self.channel = channel
39
+
40
+ def forward(self, x):
41
+ if self.norm_fn is not None:
42
+ x = self.norm_fn(x)
43
+ if self.act_fn is not None:
44
+ x = self.act_fn(x)
45
+ return x
46
+
47
+ def extra_repr(self):
48
+ return f'kind={self.kind}, channel={self.channel}'
49
+
50
+
51
+ class ConvBlock(nn.Module):
52
+ def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, groups=1):
53
+ super(ConvBlock, self).__init__()
54
+ self.inp_dim = inp_dim
55
+ self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size,
56
+ stride, padding=(kernel_size - 1) // 2, groups=groups, bias=True)
57
+ self.relu = None
58
+ self.bn = None
59
+ if relu:
60
+ self.relu = nn.ReLU()
61
+ if bn:
62
+ self.bn = nn.BatchNorm2d(out_dim)
63
+
64
+ def forward(self, x):
65
+ x = self.conv(x)
66
+ if self.bn is not None:
67
+ x = self.bn(x)
68
+ if self.relu is not None:
69
+ x = self.relu(x)
70
+ return x
71
+
72
+
73
+ class ResBlock(nn.Module):
74
+ def __init__(self, inp_dim, out_dim, mid_dim=None):
75
+ super(ResBlock, self).__init__()
76
+ if mid_dim is None:
77
+ mid_dim = out_dim // 2
78
+ self.relu = nn.ReLU()
79
+ self.bn1 = nn.BatchNorm2d(inp_dim)
80
+ self.conv1 = ConvBlock(inp_dim, mid_dim, 1, relu=False)
81
+ self.bn2 = nn.BatchNorm2d(mid_dim)
82
+ self.conv2 = ConvBlock(mid_dim, mid_dim, 3, relu=False)
83
+ self.bn3 = nn.BatchNorm2d(mid_dim)
84
+ self.conv3 = ConvBlock(mid_dim, out_dim, 1, relu=False)
85
+ self.skip_layer = ConvBlock(inp_dim, out_dim, 1, relu=False)
86
+ if inp_dim == out_dim:
87
+ self.need_skip = False
88
+ else:
89
+ self.need_skip = True
90
+
91
+ def forward(self, x):
92
+ if self.need_skip:
93
+ residual = self.skip_layer(x)
94
+ else:
95
+ residual = x
96
+ out = self.bn1(x)
97
+ out = self.relu(out)
98
+ out = self.conv1(out)
99
+ out = self.bn2(out)
100
+ out = self.relu(out)
101
+ out = self.conv2(out)
102
+ out = self.bn3(out)
103
+ out = self.relu(out)
104
+ out = self.conv3(out)
105
+ out += residual
106
+ return out
107
+
108
+
109
+ class Hourglass(nn.Module):
110
+ def __init__(self, n, f, increase=0, up_mode='nearest',
111
+ add_coord=False, first_one=False, x_dim=64, y_dim=64):
112
+ super(Hourglass, self).__init__()
113
+ nf = f + increase
114
+
115
+ Block = ResBlock
116
+
117
+ if add_coord:
118
+ self.coordconv = CoordConvTh(x_dim=x_dim, y_dim=y_dim,
119
+ with_r=True, with_boundary=True,
120
+ relu=False, bn=False,
121
+ in_channels=f, out_channels=f,
122
+ first_one=first_one,
123
+ kernel_size=1,
124
+ stride=1, padding=0)
125
+ else:
126
+ self.coordconv = None
127
+ self.up1 = Block(f, f)
128
+
129
+ # Lower branch
130
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
131
+
132
+ self.low1 = Block(f, nf)
133
+ self.n = n
134
+ # Recursive hourglass
135
+ if self.n > 1:
136
+ self.low2 = Hourglass(n=n - 1, f=nf, increase=increase, up_mode=up_mode, add_coord=False)
137
+ else:
138
+ self.low2 = Block(nf, nf)
139
+ self.low3 = Block(nf, f)
140
+ self.up2 = nn.Upsample(scale_factor=2, mode=up_mode)
141
+
142
+ def forward(self, x, heatmap=None):
143
+ if self.coordconv is not None:
144
+ x = self.coordconv(x, heatmap)
145
+ up1 = self.up1(x)
146
+ pool1 = self.pool1(x)
147
+ low1 = self.low1(pool1)
148
+ low2 = self.low2(low1)
149
+ low3 = self.low3(low2)
150
+ up2 = self.up2(low3)
151
+ return up1 + up2
152
+
153
+
154
+ class E2HTransform(nn.Module):
155
+ def __init__(self, edge_info, num_points, num_edges):
156
+ super().__init__()
157
+
158
+ e2h_matrix = np.zeros([num_points, num_edges])
159
+ for edge_id, isclosed_indices in enumerate(edge_info):
160
+ is_closed, indices = isclosed_indices
161
+ for point_id in indices:
162
+ e2h_matrix[point_id, edge_id] = 1
163
+ e2h_matrix = torch.from_numpy(e2h_matrix).float()
164
+
165
+ # pn x en x 1 x 1.
166
+ self.register_buffer('weight', e2h_matrix.view(
167
+ e2h_matrix.size(0), e2h_matrix.size(1), 1, 1))
168
+
169
+ # some keypoints are not covered by any edges,
170
+ # in these cases, we must add a constant bias to their heatmap weights.
171
+ bias = ((e2h_matrix @ torch.ones(e2h_matrix.size(1)).to(
172
+ e2h_matrix)) < 0.5).to(e2h_matrix)
173
+ # pn x 1.
174
+ self.register_buffer('bias', bias)
175
+
176
+ def forward(self, edgemaps):
177
+ # input: batch_size x en x hw x hh.
178
+ # output: batch_size x pn x hw x hh.
179
+ return F.conv2d(edgemaps, weight=self.weight, bias=self.bias)
180
+
181
+
182
+ class StackedHGNetV1(nn.Module):
183
+ def __init__(self, config, classes_num, edge_info,
184
+ nstack=4, nlevels=4, in_channel=256, increase=0,
185
+ add_coord=True, decoder_type='default'):
186
+ super(StackedHGNetV1, self).__init__()
187
+
188
+ self.cfg = config
189
+ self.coder_type = decoder_type
190
+ self.decoder = get_decoder(decoder_type=decoder_type)
191
+ self.nstack = nstack
192
+ self.add_coord = add_coord
193
+
194
+ self.num_heats = classes_num[0]
195
+
196
+ if self.add_coord:
197
+ convBlock = CoordConvTh(x_dim=self.cfg.width, y_dim=self.cfg.height,
198
+ with_r=True, with_boundary=False,
199
+ relu=True, bn=True,
200
+ in_channels=3, out_channels=64,
201
+ kernel_size=7,
202
+ stride=2, padding=3)
203
+ else:
204
+ convBlock = ConvBlock(3, 64, 7, 2, bn=True, relu=True)
205
+
206
+ pool = nn.MaxPool2d(kernel_size=2, stride=2)
207
+
208
+ Block = ResBlock
209
+
210
+ self.pre = nn.Sequential(
211
+ convBlock,
212
+ Block(64, 128),
213
+ pool,
214
+ Block(128, 128),
215
+ Block(128, in_channel)
216
+ )
217
+
218
+ self.hgs = nn.ModuleList(
219
+ [Hourglass(n=nlevels, f=in_channel, increase=increase, add_coord=self.add_coord, first_one=(_ == 0),
220
+ x_dim=int(self.cfg.width / self.nstack), y_dim=int(self.cfg.height / self.nstack))
221
+ for _ in range(nstack)])
222
+
223
+ self.features = nn.ModuleList([
224
+ nn.Sequential(
225
+ Block(in_channel, in_channel),
226
+ ConvBlock(in_channel, in_channel, 1, bn=True, relu=True)
227
+ ) for _ in range(nstack)])
228
+
229
+ self.out_heatmaps = nn.ModuleList(
230
+ [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
231
+ for _ in range(nstack)])
232
+
233
+ if self.cfg.use_AAM:
234
+ self.num_edges = classes_num[1]
235
+ self.num_points = classes_num[2]
236
+
237
+ self.e2h_transform = E2HTransform(edge_info, self.num_points, self.num_edges)
238
+ self.out_edgemaps = nn.ModuleList(
239
+ [ConvBlock(in_channel, self.num_edges, 1, relu=False, bn=False)
240
+ for _ in range(nstack)])
241
+ self.out_pointmaps = nn.ModuleList(
242
+ [ConvBlock(in_channel, self.num_points, 1, relu=False, bn=False)
243
+ for _ in range(nstack)])
244
+ self.merge_edgemaps = nn.ModuleList(
245
+ [ConvBlock(self.num_edges, in_channel, 1, relu=False, bn=False)
246
+ for _ in range(nstack - 1)])
247
+ self.merge_pointmaps = nn.ModuleList(
248
+ [ConvBlock(self.num_points, in_channel, 1, relu=False, bn=False)
249
+ for _ in range(nstack - 1)])
250
+ self.edgemap_act = Activation("sigmoid", self.num_edges)
251
+ self.pointmap_act = Activation("sigmoid", self.num_points)
252
+
253
+ self.merge_features = nn.ModuleList(
254
+ [ConvBlock(in_channel, in_channel, 1, relu=False, bn=False)
255
+ for _ in range(nstack - 1)])
256
+ self.merge_heatmaps = nn.ModuleList(
257
+ [ConvBlock(self.num_heats, in_channel, 1, relu=False, bn=False)
258
+ for _ in range(nstack - 1)])
259
+
260
+ self.nstack = nstack
261
+
262
+ self.heatmap_act = Activation("in+relu", self.num_heats)
263
+
264
+ self.inference = False
265
+
266
+ def set_inference(self, inference):
267
+ self.inference = inference
268
+
269
+ def forward(self, x):
270
+ x = self.pre(x)
271
+
272
+ y, fusionmaps = [], []
273
+ heatmaps = None
274
+ for i in range(self.nstack):
275
+ hg = self.hgs[i](x, heatmap=heatmaps)
276
+ feature = self.features[i](hg)
277
+
278
+ heatmaps0 = self.out_heatmaps[i](feature)
279
+ heatmaps = self.heatmap_act(heatmaps0)
280
+
281
+ if self.cfg.use_AAM:
282
+ pointmaps0 = self.out_pointmaps[i](feature)
283
+ pointmaps = self.pointmap_act(pointmaps0)
284
+ edgemaps0 = self.out_edgemaps[i](feature)
285
+ edgemaps = self.edgemap_act(edgemaps0)
286
+ mask = self.e2h_transform(edgemaps) * pointmaps
287
+ fusion_heatmaps = mask * heatmaps
288
+ else:
289
+ fusion_heatmaps = heatmaps
290
+
291
+ landmarks = self.decoder.get_coords_from_heatmap(fusion_heatmaps)
292
+
293
+ if i < self.nstack - 1:
294
+ x = x + self.merge_features[i](feature) + \
295
+ self.merge_heatmaps[i](heatmaps)
296
+ if self.cfg.use_AAM:
297
+ x += self.merge_pointmaps[i](pointmaps)
298
+ x += self.merge_edgemaps[i](edgemaps)
299
+
300
+ y.append(landmarks)
301
+ if self.cfg.use_AAM:
302
+ y.append(pointmaps)
303
+ y.append(edgemaps)
304
+
305
+ fusionmaps.append(fusion_heatmaps)
306
+
307
+ return y, fusionmaps, landmarks
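The point/edge fusion in `StackedHGNetV1.forward` relies on `E2HTransform`, a fixed 1x1 convolution built from `edge_info` that sums every edge map covering a point; points covered by no edge receive a constant bias of 1. A tiny sketch with made-up `edge_info`, again assuming the repository root is importable:

import torch
from external.landmark_detection.lib.backbone.stackedHGNetV1 import E2HTransform

edge_info = [(False, [0, 1]), (True, [1, 2])]   # (is_closed, point indices) per edge
e2h = E2HTransform(edge_info, num_points=4, num_edges=2)
edgemaps = torch.rand(1, 2, 8, 8)               # (batch, num_edges, h, w)
pointwise = e2h(edgemaps)                       # (batch, num_points, h, w); point 3 -> bias only
print(pointwise.shape)                          # torch.Size([1, 4, 8, 8])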