Upload 113 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- LICENCE +201 -0
- README.md +122 -12
- assets/ReadMe.md +44 -0
- assets/XVerseBench/.DS_Store +0 -0
- assets/XVerseBench/human/00_boy.jpg +0 -0
- assets/XVerseBench/human/01_man.jpg +0 -0
- assets/XVerseBench/human/02_man.jpg +0 -0
- assets/XVerseBench/human/03_woman.jpg +0 -0
- assets/XVerseBench/human/04_little girl.jpg +0 -0
- assets/XVerseBench/human/05_man.jpg +0 -0
- assets/XVerseBench/human/06_man.jpg +0 -0
- assets/XVerseBench/human/07_man.jpg +0 -0
- assets/XVerseBench/human/08_man.jpg +0 -0
- assets/XVerseBench/human/09_woman.jpg +0 -0
- assets/XVerseBench/human/10_man.jpg +0 -0
- assets/XVerseBench/human/11_man.jpg +0 -0
- assets/XVerseBench/human/12_woman.jpg +0 -0
- assets/XVerseBench/human/13_woman.jpg +0 -0
- assets/XVerseBench/human/14_boy.jpg +0 -0
- assets/XVerseBench/human/15_woman.jpg +0 -0
- assets/XVerseBench/human/16_old man.jpg +0 -0
- assets/XVerseBench/human/17_man.jpg +0 -0
- assets/XVerseBench/human/18_man.jpg +0 -0
- assets/XVerseBench/human/19_girl.jpg +0 -0
- assets/crop_faces.py +62 -0
- assets/rename.py +76 -0
- assets/segmentation.py +76 -0
- eval/eval_scripts/run_eval_multi.sh +48 -0
- eval/eval_scripts/run_eval_single.sh +48 -0
- eval/grounded_sam/florence2/config.json +85 -0
- eval/grounded_sam/florence2/configuration_florence2.py +340 -0
- eval/grounded_sam/florence2/generation_config.json +4 -0
- eval/grounded_sam/florence2/modeling_florence2.py +0 -0
- eval/grounded_sam/florence2/preprocessor_config.json +39 -0
- eval/grounded_sam/florence2/processing_florence2.py +1147 -0
- eval/grounded_sam/florence2/tokenizer.json +0 -0
- eval/grounded_sam/florence2/tokenizer_config.json +4 -0
- eval/grounded_sam/florence2/vocab.json +0 -0
- eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py +361 -0
- eval/grounded_sam/sam2/__init__.py +11 -0
- eval/grounded_sam/sam2/automatic_mask_generator.py +454 -0
- eval/grounded_sam/sam2/build_sam.py +172 -0
- eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- eval/grounded_sam/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- eval/grounded_sam/sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- eval/grounded_sam/sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
sample/first_page.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
sample/XVerseBench.png filter=lfs diff=lfs merge=lfs -text
|
LICENCE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,122 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
<a href="https://arxiv.org/abs/2506.21416">
|
5 |
+
<img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2506.21416-b31b1b.svg">
|
6 |
+
</a>
|
7 |
+
<a href="https://bytedance.github.io/XVerse/">
|
8 |
+
<img alt="Project Page" src="https://img.shields.io/badge/Project-Page-blue">
|
9 |
+
</a>
|
10 |
+
<a href="https://github.com/bytedance/XVerse/tree/main/assets">
|
11 |
+
<img alt="Build" src="https://img.shields.io/badge/XVerseBench-Dataset-green">
|
12 |
+
</a>
|
13 |
+
<a href="https://huggingface.co/ByteDance/XVerse">
|
14 |
+
<img alt="Build" src="https://img.shields.io/badge/🤗-HF%20Model-yellow">
|
15 |
+
</a>
|
16 |
+
</p>
|
17 |
+
|
18 |
+
## 🔥 News
|
19 |
+
- **2025.6.26**: The code has been released!
|
20 |
+
|
21 |
+

|
22 |
+
|
23 |
+
## 📖 Introduction
|
24 |
+
|
25 |
+
**XVerse** introduces a novel approach to multi-subject image synthesis, offering **precise and independent control over individual subjects** without disrupting the overall image latents or features. We achieve this by transforming reference images into offsets for token-specific text-stream modulation.
|
26 |
+
|
27 |
+
This innovation enables high-fidelity, editable image generation where you can robustly control both **individual subject characteristics** (identity) and their **semantic attributes**. XVerse significantly enhances capabilities for personalized and complex scene generation.
|
28 |
+
|
29 |
+
## ⚡️ Quick Start
|
30 |
+
|
31 |
+
### Requirements and Installation
|
32 |
+
|
33 |
+
First, install the necessary dependencies:
|
34 |
+
|
35 |
+
```bash
|
36 |
+
# Create a conda environment named XVerse with Python version 3.10.16
|
37 |
+
conda create -n XVerse python=3.10.16 -y
|
38 |
+
# Activate the XVerse environment
|
39 |
+
conda activate XVerse
|
40 |
+
# Use pip to install the dependencies specified in requirements.txt
|
41 |
+
pip install -r requirements.txt
|
42 |
+
```
|
43 |
+
|
44 |
+
Next, download the required checkpoints:
|
45 |
+
```bash
|
46 |
+
cd checkpoints
|
47 |
+
bash ./download_ckpts.sh
|
48 |
+
cd ..
|
49 |
+
```
|
50 |
+
**Important**: You'll also need to download the face recognition model `model_ir_se50.pth` from [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) and place it directly into the `./checkpoints/` folder.
|
51 |
+
|
52 |
+
### Local Gradio Demo
|
53 |
+
|
54 |
+
To run the interactive Gradio demo locally, execute the following command:
|
55 |
+
```bash
|
56 |
+
bash run_demo.sh
|
57 |
+
```
|
58 |
+
|
59 |
+
#### Input Settings Explained
|
60 |
+
The Gradio demo provides several parameters to control your image generation process:
|
61 |
+
* **Prompt**: The textual description guiding the image generation.
|
62 |
+
* **Generated Height/Width**: Use the sliders to set the shape of the output image.
|
63 |
+
* **Weight_id/ip**: Adjust these weight parameters. Higher values generally lead to better subject consistency but might slightly impact the naturalness of the generated image.
|
64 |
+
* **latent_lora_scale and vae_lora_scale**: Control the LoRA scale. Similar to Weight_id/ip, larger LoRA values can improve subject consistency but may reduce image naturalness.
|
65 |
+
* **vae_skip_iter_before and vae_skip_iter_after**: Configure VAE skip iterations. Skipping more steps can result in better naturalness but might compromise subject consistency.
|
66 |
+
|
67 |
+
#### Input Images
|
68 |
+
|
69 |
+
The demo provides detailed control over your input images:
|
70 |
+
|
71 |
+
* **Expand Panel**: Click "Input Image X" to reveal the options for each image.
|
72 |
+
* **Upload Image**: Click "Image X" to upload your desired reference image.
|
73 |
+
* **Image Description**: Enter a description in the "Caption X" input box. You can also click "Auto Caption" to generate a description automatically.
|
74 |
+
* **Detection & Segmentation**: Click "Det & Seg" to perform detection and segmentation on the uploaded image.
|
75 |
+
* **Crop Face**: Use "Crop Face" to automatically crop the face from the image.
|
76 |
+
* **ID Checkbox**: Check or uncheck "ID or not" to determine whether to use ID-related weights for that specific input image.
|
77 |
+
|
78 |
+
> **⚠️ Important Usage Notes:**
|
79 |
+
>
|
80 |
+
> * **Prompt Construction**: The main text prompt **MUST** include the exact text you entered in the `Image Description` field for each active image. **Generation will fail if this description is missing from the prompt.**
|
81 |
+
> * *Example*: If you upload two images and set their descriptions as "a man with red hair" (for Image 1) and "a woman with blue eyes" (for Image 2), your main prompt might be: "A `a man with red hair` walking beside `a woman with blue eyes` in a park."
|
82 |
+
> * You can then write your main prompt simply as: "`ENT1` walking beside `ENT2` in a park." The code will **automatically replace** these placeholders with the full description text before generation.
|
83 |
+
> * **Active Images**: Only images in **expanded** (un-collapsed) panels will be fed into the model. Collapsed image panels are ignored.
|
84 |
+
|
85 |
+
## Inference with XVerseBench
|
86 |
+
|
87 |
+

|
88 |
+
|
89 |
+
First, please download XVerseBench according to the contents in the `assets` folder. Then, when running inference, please execute the following command:
|
90 |
+
```bash
|
91 |
+
bash ./eval/eval_scripts/run_eval.sh
|
92 |
+
```
|
93 |
+
The script will automatically evaluate the model on the XVerseBench dataset and save the results in the `./results` folder.
|
94 |
+
|
95 |
+
## 📌 ToDo
|
96 |
+
|
97 |
+
- [x] Release github repo.
|
98 |
+
- [x] Release arXiv paper.
|
99 |
+
- [x] Release model checkpoints.
|
100 |
+
- [x] Release inference data: XVerseBench.
|
101 |
+
- [x] Release inference code for XVerseBench.
|
102 |
+
- [x] Release inference code for gradio demo.
|
103 |
+
- [ ] Release inference code for single sample.
|
104 |
+
- [ ] Release huggingface space demo.
|
105 |
+
- [ ] Release Benchmark Leaderboard.
|
106 |
+
|
107 |
+
## License
|
108 |
+
|
109 |
+
The code in this project is licensed under Apache 2.0; the dataset is licensed under CC0, subject to the intellctual property owned by Bytedance. Meanwhile, the dataset is adapted from [dreambench++](https://dreambenchplus.github.io/), you should also comply with the license of dreambench++.
|
110 |
+
|
111 |
+
## Citation
|
112 |
+
If XVerse is helpful, please help to ⭐ the repo.
|
113 |
+
|
114 |
+
If you find this project useful for your research, please consider citing our paper:
|
115 |
+
```bibtex
|
116 |
+
@article{chen2025xverse,
|
117 |
+
title={XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation},
|
118 |
+
author={Chen, Bowen and Zhao, Mengyi and Sun, Haomiao and Chen, Li and Wang, Xu and Du, Kang and Wu, Xinglong},
|
119 |
+
journal={arXiv preprint arXiv:2506.21416},
|
120 |
+
year={2025}
|
121 |
+
}
|
122 |
+
```
|
assets/ReadMe.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Install of XVerseBench
|
2 |
+
|
3 |
+
Existing controlled image generation benchmarks often focus on either maintaining identity or object appearance consistency, rarely encompassing datasets that rigorously test both aspects. To comprehensively assess the models' single-subject and multi-subject conditional generation and editing capabilities, we constructed a new benchmark by merging and curating data from DreamBench++ and some generated human images.
|
4 |
+
|
5 |
+
Our resulting benchmark XVerseBench comprises 20 distinct human identities, 74 unique objects, and 45 different animal species/individuals. To thoroughly evaluate model effectiveness in subject-driven generation tasks, we developed test sets specifically for single-subject, dual-subject, and triple-subject control scenarios. This benchmark includes 300 unique test prompts covering diverse combinations of humans, objects, and animals.
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<img src="../sample/XVerseBench.png" alt="XVerseBench">
|
9 |
+
</p>
|
10 |
+
<p align="center"><strong>Figure 1. XVerseBench</strong></p>
|
11 |
+
|
12 |
+
The above figure shows more detail information and samples for each categories. For evaluation, we employ a suite of metrics to quantify different aspects of generation quality and control fidelity: including DPG score to assess the model's editing capability, Face ID similarity and DINOv2 similarity to assess the model's preservation of human identity and objects, and Aesthetic Score to measure to evaluate the aesthetics of the generated image. XVerseBench aims to provide a more challenging and holistic evaluation framework for state-of-the-art multi-subject controllable text-to-image generation models.
|
13 |
+
|
14 |
+
## Usage
|
15 |
+
|
16 |
+
1. Download **DreamBench++** from [https://dreambenchplus.github.io/](https://dreambenchplus.github.io/) and place it into the `data/DreamBench++` directory.
|
17 |
+
2. Run the following command to rename and segementate the images:
|
18 |
+
```bash
|
19 |
+
python assets/rename.py
|
20 |
+
python assets/segmentation_sample.py
|
21 |
+
```
|
22 |
+
|
23 |
+
## Citation
|
24 |
+
If XVerseBench is helpful, please help to ⭐ the repo.
|
25 |
+
|
26 |
+
If you find this project useful for your research, please consider citing our paper:
|
27 |
+
```bibtex
|
28 |
+
@article{chen2025xverse,
|
29 |
+
title={XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation},
|
30 |
+
author={Chen, Bowen and Zhao, Mengyi and Sun, Haomiao and Chen, Li and Wang, Xu and Du, Kang and Wu, Xinglong},
|
31 |
+
journal={arXiv preprint arXiv:2506.21416},
|
32 |
+
year={2025}
|
33 |
+
}
|
34 |
+
```
|
35 |
+
|
36 |
+
|
37 |
+
> Disclaimer:
|
38 |
+
>
|
39 |
+
> Your access to and use of this dataset are at your own risk. We do not guarantee the accuracy of this dataset. The dataset is provided “as is” and we make no warranty or representation to you with respect to it and we expressly disclaim, and hereby expressly waive, all warranties, express, implied, statutory or otherwise. This includes, without limitation, warranties of quality, performance, merchantability or fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable.
|
40 |
+
>
|
41 |
+
> In no event will we be liable to you on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this public license or use of the licensed material.
|
42 |
+
>
|
43 |
+
> The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
44 |
+
|
assets/XVerseBench/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
assets/XVerseBench/human/00_boy.jpg
ADDED
![]() |
assets/XVerseBench/human/01_man.jpg
ADDED
![]() |
assets/XVerseBench/human/02_man.jpg
ADDED
![]() |
assets/XVerseBench/human/03_woman.jpg
ADDED
![]() |
assets/XVerseBench/human/04_little girl.jpg
ADDED
![]() |
assets/XVerseBench/human/05_man.jpg
ADDED
![]() |
assets/XVerseBench/human/06_man.jpg
ADDED
![]() |
assets/XVerseBench/human/07_man.jpg
ADDED
![]() |
assets/XVerseBench/human/08_man.jpg
ADDED
![]() |
assets/XVerseBench/human/09_woman.jpg
ADDED
![]() |
assets/XVerseBench/human/10_man.jpg
ADDED
![]() |
assets/XVerseBench/human/11_man.jpg
ADDED
![]() |
assets/XVerseBench/human/12_woman.jpg
ADDED
![]() |
assets/XVerseBench/human/13_woman.jpg
ADDED
![]() |
assets/XVerseBench/human/14_boy.jpg
ADDED
![]() |
assets/XVerseBench/human/15_woman.jpg
ADDED
![]() |
assets/XVerseBench/human/16_old man.jpg
ADDED
![]() |
assets/XVerseBench/human/17_man.jpg
ADDED
![]() |
assets/XVerseBench/human/18_man.jpg
ADDED
![]() |
assets/XVerseBench/human/19_girl.jpg
ADDED
![]() |
assets/crop_faces.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import face_recognition
|
3 |
+
from PIL import Image, ImageOps
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def detect_and_crop_faces(input_dir, output_dir):
|
7 |
+
# 确保输出目录存在
|
8 |
+
if not os.path.exists(output_dir):
|
9 |
+
os.makedirs(output_dir)
|
10 |
+
|
11 |
+
# 遍历输入目录中的所有文件
|
12 |
+
for filename in os.listdir(input_dir):
|
13 |
+
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
14 |
+
input_path = os.path.join(input_dir, filename)
|
15 |
+
output_path = os.path.join(output_dir, filename.replace('.png', '.jpg'))
|
16 |
+
|
17 |
+
# 加载图像并处理透明背景
|
18 |
+
image = Image.open(input_path).convert("RGBA")
|
19 |
+
background = Image.new("RGBA", image.size, "WHITE")
|
20 |
+
alpha_composite = Image.alpha_composite(background, image).convert("RGB")
|
21 |
+
|
22 |
+
# 添加白色边缘,这里 padding 设为 10 像素,可按需调整
|
23 |
+
padded_image = ImageOps.expand(alpha_composite, border=10, fill='white')
|
24 |
+
|
25 |
+
# 尝试不同尺度的图像检测
|
26 |
+
scales = [0.6, 0.4, 0.2]
|
27 |
+
face_locations = []
|
28 |
+
for scale in scales:
|
29 |
+
resized_image = padded_image.resize((int(padded_image.width * scale), int(padded_image.height * scale)), Image.LANCZOS)
|
30 |
+
image_np = np.array(resized_image)
|
31 |
+
# Use the cnn model for detection
|
32 |
+
face_locations = face_recognition.face_locations(image_np, model="cnn")
|
33 |
+
if face_locations:
|
34 |
+
# Adjust the detected face positions to the original image size
|
35 |
+
face_locations = [(int(top / scale), int(right / scale), int(bottom / scale), int(left / scale)) for top, right, bottom, left in face_locations]
|
36 |
+
break
|
37 |
+
|
38 |
+
if face_locations:
|
39 |
+
# 假设第一个检测到的人脸是需要裁剪的
|
40 |
+
top, right, bottom, left = face_locations[0]
|
41 |
+
height = bottom - top
|
42 |
+
width = right - left
|
43 |
+
|
44 |
+
# 计算扩充后的区域
|
45 |
+
new_top = max(0, int(top - height * 0.3))
|
46 |
+
new_bottom = min(np.array(padded_image).shape[0], int(bottom + height * 0.3))
|
47 |
+
new_left = max(0, int(left - width * 0.3))
|
48 |
+
new_right = min(np.array(padded_image).shape[1], int(right + width * 0.3))
|
49 |
+
|
50 |
+
face_image = np.array(padded_image)[new_top:new_bottom, new_left:new_right]
|
51 |
+
# 将 NumPy 数组转换为 PIL 图像
|
52 |
+
face_pil = Image.fromarray(face_image)
|
53 |
+
# 保存裁剪后的人脸图像
|
54 |
+
face_pil.save(output_path)
|
55 |
+
print(f"已裁剪并保存: {output_path}")
|
56 |
+
else:
|
57 |
+
print(f"未在 {input_path} 中检测到人脸")
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
input_directory = "/mnt/bn/yg-butterfly-algo/personal/sunhm/code/XVerse/assets/XVerseBench_seg/human_seg"
|
61 |
+
output_directory = "/mnt/bn/yg-butterfly-algo/personal/sunhm/code/XVerse/assets/XVerseBench_seg/human"
|
62 |
+
detect_and_crop_faces(input_directory, output_directory)
|
assets/rename.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
|
4 |
+
split = [("live_subject/animal", "animal"), ("object", "object")]
|
5 |
+
|
6 |
+
# 定义目录路径
|
7 |
+
caption_dir_base = './data/DreamBench_plus/captions'
|
8 |
+
image_dir_base = './data/DreamBench_plus/images'
|
9 |
+
new_image_dir_base = './data/XVerseBench_rename'
|
10 |
+
|
11 |
+
for s, ts in split:
|
12 |
+
caption_dir = os.path.join(caption_dir_base, s)
|
13 |
+
image_dir = os.path.join(image_dir_base, s)
|
14 |
+
new_image_dir = os.path.join(new_image_dir_base, ts)
|
15 |
+
|
16 |
+
# 创建新的目标目录(如果不存在)
|
17 |
+
if not os.path.exists(new_image_dir):
|
18 |
+
os.makedirs(new_image_dir)
|
19 |
+
|
20 |
+
# 获取所有 caption 文件
|
21 |
+
caption_files = sorted([f for f in os.listdir(caption_dir) if f.endswith('.txt')])
|
22 |
+
|
23 |
+
for caption_file in caption_files:
|
24 |
+
# 提取索引
|
25 |
+
index = os.path.splitext(caption_file)[0]
|
26 |
+
# 构建 caption 文件完整路径
|
27 |
+
caption_file_path = os.path.join(caption_dir, caption_file)
|
28 |
+
# 构建对应的图片文件路径
|
29 |
+
image_file_name = f'{index}.jpg'
|
30 |
+
image_file_path = os.path.join(image_dir, image_file_name)
|
31 |
+
|
32 |
+
# 检查图片文件是否存在
|
33 |
+
if os.path.exists(image_file_path):
|
34 |
+
# 读取 caption 文件内容
|
35 |
+
with open(caption_file_path, 'r', encoding='utf-8') as f:
|
36 |
+
caption = f.read().split('\n')[0].strip()
|
37 |
+
|
38 |
+
# 生成新的文件名
|
39 |
+
new_file_name = f'{index}_{caption}.jpg'
|
40 |
+
new_file_path_in_new_dir = os.path.join(new_image_dir, new_file_name)
|
41 |
+
|
42 |
+
# 移动并重命名文件
|
43 |
+
shutil.copy2(image_file_path, new_file_path_in_new_dir)
|
44 |
+
print(f'文件 {image_file_path} 已移动并重命名为 {new_file_path_in_new_dir}')
|
45 |
+
else:
|
46 |
+
print(f'未找到对应的图片文件: {image_file_path}')
|
47 |
+
|
48 |
+
|
49 |
+
old_human_index = ['00', '05', '06', '09', '12', '13', '14', '16', '17']
|
50 |
+
|
51 |
+
# 新增的文件映射
|
52 |
+
new_files = [
|
53 |
+
"object/65_anime space ranger.jpg", "object/66_anime girl.jpg", "object/67_pixelated warrior.jpg",
|
54 |
+
"object/68_anime girl.jpg", "object/69_anime samurai.jpg", "object/70_anime girl.jpg",
|
55 |
+
"object/71_anime Spider-Man.jpg", "object/72_Avatar.jpg", "object/73_anime man.jpg"
|
56 |
+
]
|
57 |
+
|
58 |
+
# 新增复制文件的代码
|
59 |
+
for old_human_index, new_file in zip(old_human_index, new_files):
|
60 |
+
# 构建原始图片文件路径
|
61 |
+
original_image_path = os.path.join(image_dir_base, "live_subject/human", f"{old_human_index}.jpg")
|
62 |
+
# 构建新的图片文件路径
|
63 |
+
new_image_path = os.path.join(new_image_dir_base, new_file)
|
64 |
+
|
65 |
+
# 创建新文件的目录(如果不存在)
|
66 |
+
new_image_dir = os.path.dirname(new_image_path)
|
67 |
+
if not os.path.exists(new_image_dir):
|
68 |
+
os.makedirs(new_image_dir)
|
69 |
+
|
70 |
+
# 检查原始图片文件是否存在
|
71 |
+
if os.path.exists(original_image_path):
|
72 |
+
# 复制文件
|
73 |
+
shutil.copy2(original_image_path, new_image_path)
|
74 |
+
print(f'文件 {original_image_path} 已复制到 {new_image_path}')
|
75 |
+
else:
|
76 |
+
print(f'未找到对应的图片文件: {original_image_path}')
|
assets/segmentation.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
|
2 |
+
from eval.tools.florence_sam import ObjectDetector
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from PIL import Image # 补充导入 Image 模块
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
def merge_instances(orig_img, indices, ins_bboxes, ins_images):
|
9 |
+
orig_image_width, orig_image_height = orig_img.width, orig_img.height
|
10 |
+
final_img = Image.new("RGB", (orig_image_width, orig_image_height), color=(255, 255, 255))
|
11 |
+
bboxes = []
|
12 |
+
for i in indices:
|
13 |
+
bbox = np.array(ins_bboxes[i], dtype=int).tolist()
|
14 |
+
bboxes.append(bbox)
|
15 |
+
|
16 |
+
img = cv2pil(ins_images[i])
|
17 |
+
mask = (np.array(img)[..., :3] != 255).any(axis=-1)
|
18 |
+
mask = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
|
19 |
+
final_img.paste(img, (bbox[0], bbox[1]), mask)
|
20 |
+
|
21 |
+
bbox = merge_bboxes(bboxes)
|
22 |
+
img = final_img.crop(bbox)
|
23 |
+
return img, bbox
|
24 |
+
|
25 |
+
dtype = torch.bfloat16
|
26 |
+
device = "cuda"
|
27 |
+
detector = ObjectDetector(device)
|
28 |
+
def det_seg_img(image, label):
|
29 |
+
if isinstance(image, str):
|
30 |
+
image = Image.open(image).convert("RGB")
|
31 |
+
instance_result_dict = detector.get_multiple_instances(image, label, min_size=image.size[0]//20)
|
32 |
+
indices = list(range(len(instance_result_dict["instance_images"])))
|
33 |
+
ins, bbox = merge_instances(image, indices, instance_result_dict["instance_bboxes"], instance_result_dict["instance_images"])
|
34 |
+
return ins
|
35 |
+
|
36 |
+
def segment_images_in_folder(input_folder, output_folder):
|
37 |
+
"""
|
38 |
+
对输入文件夹内所有图像进行分割,并将结果保存到输出文件夹。
|
39 |
+
|
40 |
+
:param input_folder: 输入图像文件夹路径
|
41 |
+
:param output_folder: 输出分割结果的文件夹路径
|
42 |
+
"""
|
43 |
+
# 确保输出文件夹存在
|
44 |
+
os.makedirs(output_folder, exist_ok=True)
|
45 |
+
|
46 |
+
# 遍历输入文件夹及其子文件夹内的所有文件
|
47 |
+
for root, _, filenames in os.walk(input_folder):
|
48 |
+
for filename in filenames:
|
49 |
+
# 检查是否为图像文件
|
50 |
+
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
51 |
+
file_path = os.path.join(root, filename)
|
52 |
+
try:
|
53 |
+
# 从文件名中提取标签,假设文件名格式为 "数字_标签.png"
|
54 |
+
label = filename.split('_')[-1].rsplit('.', 1)[0].strip()
|
55 |
+
# 进行图像分割
|
56 |
+
segmentation_result = det_seg_img(file_path, label)
|
57 |
+
# 构建输出文件路径,保持原文件名
|
58 |
+
relative_path = os.path.relpath(root, input_folder)
|
59 |
+
output_subfolder = os.path.join(output_folder, relative_path)
|
60 |
+
os.makedirs(output_subfolder, exist_ok=True)
|
61 |
+
output_path = os.path.join(output_subfolder, filename)
|
62 |
+
# 保存分割结果
|
63 |
+
if isinstance(segmentation_result, Image.Image):
|
64 |
+
segmentation_result.save(output_path)
|
65 |
+
else:
|
66 |
+
# 假设 segmentation_result 是可转换为 PIL Image 的对象
|
67 |
+
Image.fromarray(segmentation_result).save(output_path)
|
68 |
+
except Exception as e:
|
69 |
+
print(f"处理文件 {file_path} 时出错: {e}")
|
70 |
+
|
71 |
+
|
72 |
+
# 使用示例
|
73 |
+
if __name__ == "__main__":
|
74 |
+
input_folder = "./assets/XverseBench_rename"
|
75 |
+
output_folder = "./assets/XVerseBench"
|
76 |
+
segment_images_in_folder(input_folder, output_folder)
|
eval/eval_scripts/run_eval_multi.sh
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export config_path="./train/config/XVerse_config_INF.yaml"
|
2 |
+
export model_checkpoint="./checkpoints/XVerse"
|
3 |
+
export target_size=768
|
4 |
+
export condition_size=256
|
5 |
+
export test_list_name="XVerseBench_multi"
|
6 |
+
export save_name="./eval/XVerseBench_multi"
|
7 |
+
|
8 |
+
ports=(`echo $METIS_WORKER_0_PORT | tr ',' ' '`)
|
9 |
+
port=${ports[-1]}
|
10 |
+
|
11 |
+
accelerate launch \
|
12 |
+
--main_process_port $port \
|
13 |
+
-m eval.tools.idip_gen_split_idip \
|
14 |
+
--config_name "$config_path" \
|
15 |
+
--model_path "$model_checkpoint" \
|
16 |
+
--target_size "$target_size" \
|
17 |
+
--condition_size "$condition_size" \
|
18 |
+
--save_name "$save_name" \
|
19 |
+
--test_list_name "$test_list_name"
|
20 |
+
|
21 |
+
accelerate launch \
|
22 |
+
--main_process_port $port \
|
23 |
+
-m eval.tools.idip_dpg_score \
|
24 |
+
--input_dir "$save_name" \
|
25 |
+
--test_list_name "$test_list_name"
|
26 |
+
|
27 |
+
accelerate launch \
|
28 |
+
--main_process_port $port \
|
29 |
+
-m eval.tools.idip_aes_score \
|
30 |
+
--input_dir "$save_name" \
|
31 |
+
--test_list_name "$test_list_name"
|
32 |
+
|
33 |
+
accelerate launch \
|
34 |
+
--main_process_port $port \
|
35 |
+
-m eval.tools.idip_face_score \
|
36 |
+
--input_dir "$save_name" \
|
37 |
+
--test_list_name "$test_list_name"
|
38 |
+
|
39 |
+
accelerate launch \
|
40 |
+
--main_process_port $port \
|
41 |
+
-m eval.tools.idip_sam-dino_score \
|
42 |
+
--input_dir "$save_name" \
|
43 |
+
--test_list_name "$test_list_name"
|
44 |
+
|
45 |
+
python \
|
46 |
+
-m eval.tools.log_scores \
|
47 |
+
--input_dir "$save_name" \
|
48 |
+
--test_list_name "$test_list_name"
|
eval/eval_scripts/run_eval_single.sh
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export config_path="./train/config/XVerse_config_INF.yaml"
|
2 |
+
export model_checkpoint="./checkpoints/XVerse"
|
3 |
+
export target_size=768
|
4 |
+
export condition_size=256
|
5 |
+
export test_list_name="XVerseBench_single"
|
6 |
+
export save_name="./eval/XVerseBench_singleidip"
|
7 |
+
|
8 |
+
ports=(`echo $METIS_WORKER_0_PORT | tr ',' ' '`)
|
9 |
+
port=${ports[-1]}
|
10 |
+
|
11 |
+
accelerate launch \
|
12 |
+
--main_process_port $port \
|
13 |
+
-m eval.tools.idip_gen_split_idip \
|
14 |
+
--config_name "$config_path" \
|
15 |
+
--model_path "$model_checkpoint" \
|
16 |
+
--target_size "$target_size" \
|
17 |
+
--condition_size "$condition_size" \
|
18 |
+
--save_name "$save_name" \
|
19 |
+
--test_list_name "$test_list_name"
|
20 |
+
|
21 |
+
accelerate launch \
|
22 |
+
--main_process_port $port \
|
23 |
+
-m eval.tools.idip_dpg_score \
|
24 |
+
--input_dir "$save_name" \
|
25 |
+
--test_list_name "$test_list_name"
|
26 |
+
|
27 |
+
accelerate launch \
|
28 |
+
--main_process_port $port \
|
29 |
+
-m eval.tools.idip_aes_score \
|
30 |
+
--input_dir "$save_name" \
|
31 |
+
--test_list_name "$test_list_name"
|
32 |
+
|
33 |
+
accelerate launch \
|
34 |
+
--main_process_port $port \
|
35 |
+
-m eval.tools.idip_face_score \
|
36 |
+
--input_dir "$save_name" \
|
37 |
+
--test_list_name "$test_list_name"
|
38 |
+
|
39 |
+
accelerate launch \
|
40 |
+
--main_process_port $port \
|
41 |
+
-m eval.tools.idip_sam-dino_score \
|
42 |
+
--input_dir "$save_name" \
|
43 |
+
--test_list_name "$test_list_name"
|
44 |
+
|
45 |
+
python \
|
46 |
+
-m eval.tools.log_scores \
|
47 |
+
--input_dir "$save_name" \
|
48 |
+
--test_list_name "$test_list_name"
|
eval/grounded_sam/florence2/config.json
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "florence2",
|
3 |
+
"architectures": [
|
4 |
+
"Florence2ForConditionalGeneration"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_florence2.Florence2Config",
|
8 |
+
"AutoModelForCausalLM": "modeling_florence2.Florence2ForConditionalGeneration"
|
9 |
+
},
|
10 |
+
"bos_token_id": 0,
|
11 |
+
"eos_token_id": 2,
|
12 |
+
"ignore_index": -100,
|
13 |
+
"model_type": "florence2",
|
14 |
+
"pad_token_id": 1,
|
15 |
+
"projection_dim": 1024,
|
16 |
+
"text_config": {
|
17 |
+
"vocab_size": 51289,
|
18 |
+
"activation_dropout": 0.1,
|
19 |
+
"activation_function": "gelu",
|
20 |
+
"add_bias_logits": false,
|
21 |
+
"add_final_layer_norm": false,
|
22 |
+
"attention_dropout": 0.1,
|
23 |
+
"bos_token_id": 0,
|
24 |
+
"classif_dropout": 0.1,
|
25 |
+
"classifier_dropout": 0.0,
|
26 |
+
"d_model": 1024,
|
27 |
+
"decoder_attention_heads": 16,
|
28 |
+
"decoder_ffn_dim": 4096,
|
29 |
+
"decoder_layerdrop": 0.0,
|
30 |
+
"decoder_layers": 12,
|
31 |
+
"decoder_start_token_id": 2,
|
32 |
+
"dropout": 0.1,
|
33 |
+
"early_stopping": true,
|
34 |
+
"encoder_attention_heads": 16,
|
35 |
+
"encoder_ffn_dim": 4096,
|
36 |
+
"encoder_layerdrop": 0.0,
|
37 |
+
"encoder_layers": 12,
|
38 |
+
"eos_token_id": 2,
|
39 |
+
"forced_eos_token_id": 2,
|
40 |
+
"forced_bos_token_id": 0,
|
41 |
+
"gradient_checkpointing": false,
|
42 |
+
"init_std": 0.02,
|
43 |
+
"is_encoder_decoder": true,
|
44 |
+
"label2id": {
|
45 |
+
"LABEL_0": 0,
|
46 |
+
"LABEL_1": 1,
|
47 |
+
"LABEL_2": 2
|
48 |
+
},
|
49 |
+
"max_position_embeddings": 1024,
|
50 |
+
"no_repeat_ngram_size": 3,
|
51 |
+
"normalize_before": false,
|
52 |
+
"num_hidden_layers": 12,
|
53 |
+
"pad_token_id": 1,
|
54 |
+
"scale_embedding": false,
|
55 |
+
"num_beams": 3
|
56 |
+
},
|
57 |
+
"vision_config": {
|
58 |
+
"model_type": "davit",
|
59 |
+
"drop_path_rate": 0.1,
|
60 |
+
"patch_size": [7, 3, 3, 3],
|
61 |
+
"patch_stride": [4, 2, 2, 2],
|
62 |
+
"patch_padding": [3, 1, 1, 1],
|
63 |
+
"patch_prenorm": [false, true, true, true],
|
64 |
+
"enable_checkpoint": false,
|
65 |
+
"dim_embed": [256, 512, 1024, 2048],
|
66 |
+
"num_heads": [8, 16, 32, 64],
|
67 |
+
"num_groups": [8, 16, 32, 64],
|
68 |
+
"depths": [1, 1, 9, 1],
|
69 |
+
"window_size": 12,
|
70 |
+
"projection_dim": 1024,
|
71 |
+
"visual_temporal_embedding": {
|
72 |
+
"type": "COSINE",
|
73 |
+
"max_temporal_embeddings": 100
|
74 |
+
},
|
75 |
+
"image_pos_embed": {
|
76 |
+
"type": "learned_abs_2d",
|
77 |
+
"max_pos_embeddings": 50
|
78 |
+
},
|
79 |
+
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
|
80 |
+
},
|
81 |
+
"vocab_size": 51289,
|
82 |
+
"torch_dtype": "float16",
|
83 |
+
"transformers_version": "4.41.0.dev0",
|
84 |
+
"is_encoder_decoder": true
|
85 |
+
}
|
eval/grounded_sam/florence2/configuration_florence2.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import warnings
|
15 |
+
""" Florence-2 configuration"""
|
16 |
+
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
from transformers import AutoConfig
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
class Florence2VisionConfig(PretrainedConfig):
|
26 |
+
r"""
|
27 |
+
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
28 |
+
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
29 |
+
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
|
30 |
+
|
31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
32 |
+
documentation from [`PretrainedConfig`] for more information.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
36 |
+
The dropout rate of the drop path layer.
|
37 |
+
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
|
38 |
+
The patch size of the image.
|
39 |
+
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
|
40 |
+
The patch stride of the image.
|
41 |
+
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
|
42 |
+
The patch padding of the image.
|
43 |
+
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
|
44 |
+
Whether to apply layer normalization before the patch embedding layer.
|
45 |
+
enable_checkpoint (`bool`, *optional*, defaults to False):
|
46 |
+
Whether to enable checkpointing.
|
47 |
+
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
|
48 |
+
The dimension of the embedding layer.
|
49 |
+
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
50 |
+
The number of attention heads.
|
51 |
+
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
52 |
+
The number of groups.
|
53 |
+
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
|
54 |
+
The depth of the model.
|
55 |
+
window_size (`int`, *optional*, defaults to 12):
|
56 |
+
The window size of the model.
|
57 |
+
projection_dim (`int`, *optional*, defaults to 1024):
|
58 |
+
The dimension of the projection layer.
|
59 |
+
visual_temporal_embedding (`dict`, *optional*):
|
60 |
+
The configuration of the visual temporal embedding.
|
61 |
+
image_pos_embed (`dict`, *optional*):
|
62 |
+
The configuration of the image position embedding.
|
63 |
+
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
|
64 |
+
The source of the image feature.
|
65 |
+
Example:
|
66 |
+
|
67 |
+
```python
|
68 |
+
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
|
69 |
+
|
70 |
+
>>> # Initializing a Florence2 Vision style configuration
|
71 |
+
>>> configuration = Florence2VisionConfig()
|
72 |
+
|
73 |
+
>>> # Initializing a model (with random weights)
|
74 |
+
>>> model = Florence2VisionModel(configuration)
|
75 |
+
|
76 |
+
>>> # Accessing the model configuration
|
77 |
+
>>> configuration = model.config
|
78 |
+
```"""
|
79 |
+
|
80 |
+
model_type = "florence2_vision"
|
81 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
drop_path_rate=0.1,
|
86 |
+
patch_size=[7, 3, 3, 3],
|
87 |
+
patch_stride=[4, 2, 2, 2],
|
88 |
+
patch_padding=[3, 1, 1, 1],
|
89 |
+
patch_prenorm=[False, True, True, True],
|
90 |
+
enable_checkpoint=False,
|
91 |
+
dim_embed=[256, 512, 1024, 2048],
|
92 |
+
num_heads=[8, 16, 32, 64],
|
93 |
+
num_groups=[8, 16, 32, 64],
|
94 |
+
depths=[1, 1, 9, 1],
|
95 |
+
window_size=12,
|
96 |
+
projection_dim=1024,
|
97 |
+
visual_temporal_embedding=None,
|
98 |
+
image_pos_embed=None,
|
99 |
+
image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
|
100 |
+
**kwargs,
|
101 |
+
):
|
102 |
+
self.drop_path_rate = drop_path_rate
|
103 |
+
self.patch_size = patch_size
|
104 |
+
self.patch_stride = patch_stride
|
105 |
+
self.patch_padding = patch_padding
|
106 |
+
self.patch_prenorm = patch_prenorm
|
107 |
+
self.enable_checkpoint = enable_checkpoint
|
108 |
+
self.dim_embed = dim_embed
|
109 |
+
self.num_heads = num_heads
|
110 |
+
self.num_groups = num_groups
|
111 |
+
self.depths = depths
|
112 |
+
self.window_size = window_size
|
113 |
+
self.projection_dim = projection_dim
|
114 |
+
self.visual_temporal_embedding = visual_temporal_embedding
|
115 |
+
self.image_pos_embed = image_pos_embed
|
116 |
+
self.image_feature_source = image_feature_source
|
117 |
+
|
118 |
+
super().__init__(**kwargs)
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
class Florence2LanguageConfig(PretrainedConfig):
|
123 |
+
r"""
|
124 |
+
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
125 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
126 |
+
defaults will yield a similar configuration to that of the BART
|
127 |
+
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
|
128 |
+
|
129 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
130 |
+
documentation from [`PretrainedConfig`] for more information.
|
131 |
+
|
132 |
+
|
133 |
+
Args:
|
134 |
+
vocab_size (`int`, *optional*, defaults to 51289):
|
135 |
+
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
|
136 |
+
`inputs_ids` passed when calling [`Florence2LanguageModel`].
|
137 |
+
d_model (`int`, *optional*, defaults to 1024):
|
138 |
+
Dimensionality of the layers and the pooler layer.
|
139 |
+
encoder_layers (`int`, *optional*, defaults to 12):
|
140 |
+
Number of encoder layers.
|
141 |
+
decoder_layers (`int`, *optional*, defaults to 12):
|
142 |
+
Number of decoder layers.
|
143 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
144 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
145 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
146 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
147 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
148 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
149 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
150 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
151 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
152 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
153 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
154 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
155 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
156 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
157 |
+
The dropout ratio for the attention probabilities.
|
158 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
159 |
+
The dropout ratio for activations inside the fully connected layer.
|
160 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
161 |
+
The dropout ratio for classifier.
|
162 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
163 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
164 |
+
just in case (e.g., 512 or 1024 or 2048).
|
165 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
166 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
167 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
168 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
169 |
+
for more details.
|
170 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
171 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
172 |
+
for more details.
|
173 |
+
scale_embedding (`bool`, *optional*, defaults to `False`):
|
174 |
+
Scale embeddings by diving by sqrt(d_model).
|
175 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
176 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
177 |
+
num_labels (`int`, *optional*, defaults to 3):
|
178 |
+
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
|
179 |
+
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
180 |
+
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
181 |
+
`eos_token_id`.
|
182 |
+
|
183 |
+
Example:
|
184 |
+
|
185 |
+
```python
|
186 |
+
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
|
187 |
+
|
188 |
+
>>> # Initializing a Florence2 Language style configuration
|
189 |
+
>>> configuration = Florence2LanguageConfig()
|
190 |
+
|
191 |
+
>>> # Initializing a model (with random weights)
|
192 |
+
>>> model = Florence2LangaugeModel(configuration)
|
193 |
+
|
194 |
+
>>> # Accessing the model configuration
|
195 |
+
>>> configuration = model.config
|
196 |
+
```"""
|
197 |
+
|
198 |
+
model_type = "florence2_language"
|
199 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
200 |
+
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
201 |
+
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
vocab_size=51289,
|
205 |
+
max_position_embeddings=1024,
|
206 |
+
encoder_layers=12,
|
207 |
+
encoder_ffn_dim=4096,
|
208 |
+
encoder_attention_heads=16,
|
209 |
+
decoder_layers=12,
|
210 |
+
decoder_ffn_dim=4096,
|
211 |
+
decoder_attention_heads=16,
|
212 |
+
encoder_layerdrop=0.0,
|
213 |
+
decoder_layerdrop=0.0,
|
214 |
+
activation_function="gelu",
|
215 |
+
d_model=1024,
|
216 |
+
dropout=0.1,
|
217 |
+
attention_dropout=0.0,
|
218 |
+
activation_dropout=0.0,
|
219 |
+
init_std=0.02,
|
220 |
+
classifier_dropout=0.0,
|
221 |
+
scale_embedding=False,
|
222 |
+
use_cache=True,
|
223 |
+
num_labels=3,
|
224 |
+
pad_token_id=1,
|
225 |
+
bos_token_id=0,
|
226 |
+
eos_token_id=2,
|
227 |
+
is_encoder_decoder=True,
|
228 |
+
decoder_start_token_id=2,
|
229 |
+
forced_eos_token_id=2,
|
230 |
+
**kwargs,
|
231 |
+
):
|
232 |
+
self.vocab_size = vocab_size
|
233 |
+
self.max_position_embeddings = max_position_embeddings
|
234 |
+
self.d_model = d_model
|
235 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
236 |
+
self.encoder_layers = encoder_layers
|
237 |
+
self.encoder_attention_heads = encoder_attention_heads
|
238 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
239 |
+
self.decoder_layers = decoder_layers
|
240 |
+
self.decoder_attention_heads = decoder_attention_heads
|
241 |
+
self.dropout = dropout
|
242 |
+
self.attention_dropout = attention_dropout
|
243 |
+
self.activation_dropout = activation_dropout
|
244 |
+
self.activation_function = activation_function
|
245 |
+
self.init_std = init_std
|
246 |
+
self.encoder_layerdrop = encoder_layerdrop
|
247 |
+
self.decoder_layerdrop = decoder_layerdrop
|
248 |
+
self.classifier_dropout = classifier_dropout
|
249 |
+
self.use_cache = use_cache
|
250 |
+
self.num_hidden_layers = encoder_layers
|
251 |
+
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
252 |
+
|
253 |
+
super().__init__(
|
254 |
+
num_labels=num_labels,
|
255 |
+
pad_token_id=pad_token_id,
|
256 |
+
bos_token_id=bos_token_id,
|
257 |
+
eos_token_id=eos_token_id,
|
258 |
+
is_encoder_decoder=is_encoder_decoder,
|
259 |
+
decoder_start_token_id=decoder_start_token_id,
|
260 |
+
forced_eos_token_id=forced_eos_token_id,
|
261 |
+
**kwargs,
|
262 |
+
)
|
263 |
+
|
264 |
+
# ensure backward compatibility for BART CNN models
|
265 |
+
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
266 |
+
self.forced_bos_token_id = self.bos_token_id
|
267 |
+
warnings.warn(
|
268 |
+
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
269 |
+
"The config can simply be saved and uploaded again to be fixed."
|
270 |
+
)
|
271 |
+
|
272 |
+
class Florence2Config(PretrainedConfig):
|
273 |
+
r"""
|
274 |
+
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
275 |
+
Florence-2 model according to the specified arguments, defining the model architecture.
|
276 |
+
|
277 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
278 |
+
documentation from [`PretrainedConfig`] for more information.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
vision_config (`Florence2VisionConfig`, *optional*):
|
282 |
+
Custom vision config or dict
|
283 |
+
text_config (`Union[AutoConfig, dict]`, *optional*):
|
284 |
+
The config object of the text backbone.
|
285 |
+
ignore_index (`int`, *optional*, defaults to -100):
|
286 |
+
The ignore index for the loss function.
|
287 |
+
vocab_size (`int`, *optional*, defaults to 51289):
|
288 |
+
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
|
289 |
+
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
|
290 |
+
projection_dim (`int`, *optional*, defaults to 1024):
|
291 |
+
Dimension of the multimodal projection space.
|
292 |
+
|
293 |
+
Example:
|
294 |
+
|
295 |
+
```python
|
296 |
+
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
|
297 |
+
|
298 |
+
>>> # Initializing a clip-like vision config
|
299 |
+
>>> vision_config = CLIPVisionConfig()
|
300 |
+
|
301 |
+
>>> # Initializing a Bart config
|
302 |
+
>>> text_config = BartConfig()
|
303 |
+
|
304 |
+
>>> # Initializing a Florence-2 configuration
|
305 |
+
>>> configuration = Florence2Config(vision_config, text_config)
|
306 |
+
|
307 |
+
>>> # Initializing a model from the florence-2 configuration
|
308 |
+
>>> model = Florence2ForConditionalGeneration(configuration)
|
309 |
+
|
310 |
+
>>> # Accessing the model configuration
|
311 |
+
>>> configuration = model.config
|
312 |
+
```"""
|
313 |
+
|
314 |
+
model_type = "florence2"
|
315 |
+
is_composition = False
|
316 |
+
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
vision_config=None,
|
320 |
+
text_config=None,
|
321 |
+
ignore_index=-100,
|
322 |
+
vocab_size=51289,
|
323 |
+
projection_dim=1024,
|
324 |
+
**kwargs,
|
325 |
+
):
|
326 |
+
self.ignore_index = ignore_index
|
327 |
+
self.vocab_size = vocab_size
|
328 |
+
self.projection_dim = projection_dim
|
329 |
+
if vision_config is not None:
|
330 |
+
vision_config = PretrainedConfig(**vision_config)
|
331 |
+
self.vision_config = vision_config
|
332 |
+
self.vocab_size = self.vocab_size
|
333 |
+
|
334 |
+
self.text_config = text_config
|
335 |
+
if text_config is not None:
|
336 |
+
self.text_config = Florence2LanguageConfig(**text_config)
|
337 |
+
|
338 |
+
|
339 |
+
super().__init__(**kwargs)
|
340 |
+
|
eval/grounded_sam/florence2/generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"num_beams": 3,
|
3 |
+
"early_stopping": false
|
4 |
+
}
|
eval/grounded_sam/florence2/modeling_florence2.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
eval/grounded_sam/florence2/preprocessor_config.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoProcessor": "processing_florence2.Florence2Processor"
|
4 |
+
},
|
5 |
+
"_valid_processor_keys": [
|
6 |
+
"images",
|
7 |
+
"do_resize",
|
8 |
+
"size",
|
9 |
+
"resample",
|
10 |
+
"do_rescale",
|
11 |
+
"rescale_factor",
|
12 |
+
"do_normalize",
|
13 |
+
"image_mean",
|
14 |
+
"image_std",
|
15 |
+
"return_tensors",
|
16 |
+
"data_format",
|
17 |
+
"input_data_format",
|
18 |
+
"do_convert_rgb"
|
19 |
+
],
|
20 |
+
"do_convert_rgb": null,
|
21 |
+
"do_normalize": true,
|
22 |
+
"do_rescale": true,
|
23 |
+
"do_resize": true,
|
24 |
+
"do_center_crop": false,
|
25 |
+
"image_processor_type": "CLIPImageProcessor",
|
26 |
+
"image_seq_length": 577,
|
27 |
+
"image_mean": [0.485, 0.456, 0.406],
|
28 |
+
"image_std": [0.229, 0.224, 0.225],
|
29 |
+
"processor_class": "Florence2Processor",
|
30 |
+
"resample": 3,
|
31 |
+
"size": {
|
32 |
+
"height": 768,
|
33 |
+
"width":768
|
34 |
+
},
|
35 |
+
"crop_size": {
|
36 |
+
"height": 768,
|
37 |
+
"width": 768
|
38 |
+
}
|
39 |
+
}
|
eval/grounded_sam/florence2/processing_florence2.py
ADDED
@@ -0,0 +1,1147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Microsoft and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Processor class for Florence-2.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import re
|
20 |
+
import logging
|
21 |
+
from typing import List, Optional, Union
|
22 |
+
import numpy as np
|
23 |
+
import math
|
24 |
+
|
25 |
+
import torch
|
26 |
+
|
27 |
+
from transformers.feature_extraction_utils import BatchFeature
|
28 |
+
from transformers.image_utils import ImageInput, is_valid_image
|
29 |
+
from transformers.processing_utils import ProcessorMixin
|
30 |
+
from transformers.tokenization_utils_base import (
|
31 |
+
PaddingStrategy,
|
32 |
+
PreTokenizedInput,
|
33 |
+
TextInput,
|
34 |
+
TruncationStrategy,
|
35 |
+
)
|
36 |
+
from transformers import BartTokenizer, BartTokenizerFast
|
37 |
+
from transformers.utils import TensorType
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.getLogger(__name__)
|
41 |
+
|
42 |
+
# Copied from transformers.models.idefics2.processing_idefics2.is_url
|
43 |
+
def is_url(val) -> bool:
|
44 |
+
return isinstance(val, str) and val.startswith("http")
|
45 |
+
|
46 |
+
# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
|
47 |
+
def is_image_or_image_url(elem):
|
48 |
+
return is_url(elem) or is_valid_image(elem)
|
49 |
+
|
50 |
+
|
51 |
+
def _is_str_or_image(elem):
|
52 |
+
return isinstance(elem, (str)) or is_image_or_image_url(elem)
|
53 |
+
|
54 |
+
|
55 |
+
class Florence2Processor(ProcessorMixin):
|
56 |
+
r"""
|
57 |
+
Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
|
58 |
+
|
59 |
+
[`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
|
60 |
+
[`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
image_processor ([`CLIPImageProcessor`], *optional*):
|
64 |
+
The image processor is a required input.
|
65 |
+
tokenizer ([`BartTokenizerFast`], *optional*):
|
66 |
+
The tokenizer is a required input.
|
67 |
+
"""
|
68 |
+
|
69 |
+
attributes = ["image_processor", "tokenizer"]
|
70 |
+
image_processor_class = "CLIPImageProcessor"
|
71 |
+
tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
image_processor=None,
|
76 |
+
tokenizer=None,
|
77 |
+
):
|
78 |
+
if image_processor is None:
|
79 |
+
raise ValueError("You need to specify an `image_processor`.")
|
80 |
+
if tokenizer is None:
|
81 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
82 |
+
if not hasattr(image_processor, "image_seq_length"):
|
83 |
+
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
|
84 |
+
|
85 |
+
self.image_seq_length = image_processor.image_seq_length
|
86 |
+
|
87 |
+
tokens_to_add = {
|
88 |
+
'additional_special_tokens': \
|
89 |
+
tokenizer.additional_special_tokens + \
|
90 |
+
['<od>', '</od>', '<ocr>', '</ocr>'] + \
|
91 |
+
[f'<loc_{x}>' for x in range(1000)] + \
|
92 |
+
['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
|
93 |
+
}
|
94 |
+
tokenizer.add_special_tokens(tokens_to_add)
|
95 |
+
|
96 |
+
self.tasks_answer_post_processing_type = {
|
97 |
+
'<OCR>': 'pure_text',
|
98 |
+
'<OCR_WITH_REGION>': 'ocr',
|
99 |
+
'<CAPTION>': 'pure_text',
|
100 |
+
'<DETAILED_CAPTION>': 'pure_text',
|
101 |
+
'<MORE_DETAILED_CAPTION>': 'pure_text',
|
102 |
+
'<OD>': 'description_with_bboxes',
|
103 |
+
'<DENSE_REGION_CAPTION>': 'description_with_bboxes',
|
104 |
+
'<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
|
105 |
+
'<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
|
106 |
+
'<REGION_TO_SEGMENTATION>': 'polygons',
|
107 |
+
'<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
|
108 |
+
'<REGION_TO_CATEGORY>': 'pure_text',
|
109 |
+
'<REGION_TO_DESCRIPTION>': 'pure_text',
|
110 |
+
'<REGION_TO_OCR>': 'pure_text',
|
111 |
+
'<REGION_PROPOSAL>': 'bboxes'
|
112 |
+
}
|
113 |
+
|
114 |
+
self.task_prompts_without_inputs = {
|
115 |
+
'<OCR>': 'What is the text in the image?',
|
116 |
+
'<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
|
117 |
+
'<CAPTION>': 'What does the image describe?',
|
118 |
+
'<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
|
119 |
+
'<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
|
120 |
+
'<OD>': 'Locate the objects with category name in the image.',
|
121 |
+
'<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
|
122 |
+
'<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
|
123 |
+
}
|
124 |
+
|
125 |
+
self.task_prompts_with_input = {
|
126 |
+
'<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
|
127 |
+
'<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
|
128 |
+
'<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
|
129 |
+
'<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
|
130 |
+
'<REGION_TO_CATEGORY>': 'What is the region {input}?',
|
131 |
+
'<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
|
132 |
+
'<REGION_TO_OCR>': 'What text is in the region {input}?',
|
133 |
+
}
|
134 |
+
|
135 |
+
self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
|
136 |
+
|
137 |
+
|
138 |
+
super().__init__(image_processor, tokenizer)
|
139 |
+
|
140 |
+
def _construct_prompts(self, text):
|
141 |
+
# replace the task tokens with the task prompts if task token is in the text
|
142 |
+
prompts = []
|
143 |
+
for _text in text:
|
144 |
+
# 1. fixed task prompts without additional inputs
|
145 |
+
for task_token, task_prompt in self.task_prompts_without_inputs.items():
|
146 |
+
if task_token in _text:
|
147 |
+
assert _text == task_token, f"Task token {task_token} should be the only token in the text."
|
148 |
+
_text = task_prompt
|
149 |
+
break
|
150 |
+
# 2. task prompts with additional inputs
|
151 |
+
for task_token, task_prompt in self.task_prompts_with_input.items():
|
152 |
+
if task_token in _text:
|
153 |
+
_text = task_prompt.format(input=_text.replace(task_token, ''))
|
154 |
+
break
|
155 |
+
prompts.append(_text)
|
156 |
+
return prompts
|
157 |
+
|
158 |
+
def __call__(
|
159 |
+
self,
|
160 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
161 |
+
images: ImageInput = None,
|
162 |
+
tokenize_newline_separately: bool = True,
|
163 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
164 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
165 |
+
max_length=None,
|
166 |
+
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
167 |
+
do_resize: bool = None,
|
168 |
+
do_normalize: bool = None,
|
169 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
170 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
171 |
+
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
|
172 |
+
input_data_format: Optional[
|
173 |
+
Union[str, "ChannelDimension"] # noqa: F821
|
174 |
+
] = None,
|
175 |
+
resample: "PILImageResampling" = None, # noqa: F821
|
176 |
+
do_convert_rgb: bool = None,
|
177 |
+
do_thumbnail: bool = None,
|
178 |
+
do_align_long_axis: bool = None,
|
179 |
+
do_rescale: bool = None,
|
180 |
+
) -> BatchFeature:
|
181 |
+
"""
|
182 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
183 |
+
and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
|
184 |
+
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
185 |
+
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
186 |
+
of the above two methods for more information.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
190 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
191 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
192 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
193 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
194 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
195 |
+
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
196 |
+
number of channels, H and W are image height and width.
|
197 |
+
tokenize_newline_separately (`bool`, defaults to `True`):
|
198 |
+
Adds a separately tokenized '\n' at the end of the prompt.
|
199 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
200 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
201 |
+
index) among:
|
202 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
203 |
+
sequence if provided).
|
204 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
205 |
+
acceptable input length for the model if that argument is not provided.
|
206 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
207 |
+
lengths).
|
208 |
+
max_length (`int`, *optional*):
|
209 |
+
Maximum length of the returned list and optionally padding length (see above).
|
210 |
+
truncation (`bool`, *optional*):
|
211 |
+
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
212 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
213 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
214 |
+
|
215 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
216 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
217 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
218 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
222 |
+
|
223 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
|
224 |
+
is provided, the `input_ids` will also contain the suffix input ids.
|
225 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
226 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
227 |
+
`None`).
|
228 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
229 |
+
- **labels** -- Labels compatible with training if `suffix` is not None
|
230 |
+
"""
|
231 |
+
|
232 |
+
return_token_type_ids = False
|
233 |
+
|
234 |
+
if images is None:
|
235 |
+
raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
|
236 |
+
if text is None:
|
237 |
+
logger.warning_once(
|
238 |
+
"You are using Florence-2 without a text prompt."
|
239 |
+
)
|
240 |
+
text = ""
|
241 |
+
|
242 |
+
if isinstance(text, List) and isinstance(images, List):
|
243 |
+
if len(images) < len(text):
|
244 |
+
raise ValueError(
|
245 |
+
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
|
246 |
+
)
|
247 |
+
if _is_str_or_image(text):
|
248 |
+
text = [text]
|
249 |
+
elif isinstance(text, list) and _is_str_or_image(text[0]):
|
250 |
+
pass
|
251 |
+
|
252 |
+
pixel_values = self.image_processor(
|
253 |
+
images,
|
254 |
+
do_resize=do_resize,
|
255 |
+
do_normalize=do_normalize,
|
256 |
+
return_tensors=return_tensors,
|
257 |
+
image_mean=image_mean,
|
258 |
+
image_std=image_std,
|
259 |
+
input_data_format=input_data_format,
|
260 |
+
data_format=data_format,
|
261 |
+
resample=resample,
|
262 |
+
do_convert_rgb=do_convert_rgb,
|
263 |
+
)["pixel_values"]
|
264 |
+
|
265 |
+
if max_length is not None:
|
266 |
+
max_length -= self.image_seq_length # max_length has to account for the image tokens
|
267 |
+
|
268 |
+
text = self._construct_prompts(text)
|
269 |
+
|
270 |
+
inputs = self.tokenizer(
|
271 |
+
text,
|
272 |
+
return_tensors=return_tensors,
|
273 |
+
padding=padding,
|
274 |
+
max_length=max_length,
|
275 |
+
truncation=truncation,
|
276 |
+
return_token_type_ids=return_token_type_ids,
|
277 |
+
)
|
278 |
+
|
279 |
+
return_data = {**inputs, "pixel_values": pixel_values}
|
280 |
+
|
281 |
+
if return_token_type_ids:
|
282 |
+
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
|
283 |
+
return_data.update({"labels": labels})
|
284 |
+
return BatchFeature(data=return_data)
|
285 |
+
|
286 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
|
287 |
+
def batch_decode(self, *args, **kwargs):
|
288 |
+
"""
|
289 |
+
This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
290 |
+
refer to the docstring of this method for more information.
|
291 |
+
"""
|
292 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
293 |
+
|
294 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
|
295 |
+
def decode(self, *args, **kwargs):
|
296 |
+
"""
|
297 |
+
This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
298 |
+
the docstring of this method for more information.
|
299 |
+
"""
|
300 |
+
return self.tokenizer.decode(*args, **kwargs)
|
301 |
+
|
302 |
+
@property
|
303 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
|
304 |
+
def model_input_names(self):
|
305 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
306 |
+
image_processor_input_names = self.image_processor.model_input_names
|
307 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
308 |
+
|
309 |
+
def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
|
310 |
+
"""
|
311 |
+
Post-process the output of the model to each of the task outputs.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
text (`str`): The text to post-process.
|
315 |
+
task (`str`): The task to post-process the text for.
|
316 |
+
image_size (`Tuple[int, int]`): The size of the image. height x width.
|
317 |
+
"""
|
318 |
+
|
319 |
+
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
320 |
+
task_answer = self.post_processor(
|
321 |
+
text=text,
|
322 |
+
sequence=sequence,
|
323 |
+
transition_beam_score=transition_beam_score,
|
324 |
+
image_size=image_size,
|
325 |
+
parse_tasks=task_answer_post_processing_type,
|
326 |
+
)[task_answer_post_processing_type]
|
327 |
+
|
328 |
+
if task_answer_post_processing_type == 'pure_text':
|
329 |
+
final_answer = task_answer
|
330 |
+
# remove the special tokens
|
331 |
+
final_answer = final_answer.replace('<s>', '').replace('</s>', '')
|
332 |
+
elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
|
333 |
+
od_instances = task_answer
|
334 |
+
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
335 |
+
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
336 |
+
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
337 |
+
if len(od_instances) and 'score' in od_instances[0]:
|
338 |
+
scores_od = [_od_instance['score'] for _od_instance in od_instances]
|
339 |
+
final_answer['scores'] = scores_od
|
340 |
+
elif task_answer_post_processing_type in ['ocr']:
|
341 |
+
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
342 |
+
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
343 |
+
final_answer = {'quad_boxes': bboxes, 'labels': labels}
|
344 |
+
elif task_answer_post_processing_type in ['phrase_grounding']:
|
345 |
+
bboxes = []
|
346 |
+
labels = []
|
347 |
+
for _grounded_phrase in task_answer:
|
348 |
+
for _bbox in _grounded_phrase['bbox']:
|
349 |
+
bboxes.append(_bbox)
|
350 |
+
labels.append(_grounded_phrase['cat_name'])
|
351 |
+
final_answer = {'bboxes': bboxes, 'labels': labels}
|
352 |
+
elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
|
353 |
+
labels = []
|
354 |
+
polygons = []
|
355 |
+
for result in task_answer:
|
356 |
+
label = result['cat_name']
|
357 |
+
_polygons = result['polygons']
|
358 |
+
labels.append(label)
|
359 |
+
polygons.append(_polygons)
|
360 |
+
final_answer = {'polygons': polygons, 'labels': labels}
|
361 |
+
elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
|
362 |
+
bboxes = []
|
363 |
+
bboxes_labels = []
|
364 |
+
polygons = []
|
365 |
+
polygons_labels = []
|
366 |
+
for result in task_answer:
|
367 |
+
label = result['cat_name']
|
368 |
+
if 'polygons' in result:
|
369 |
+
_polygons = result['polygons']
|
370 |
+
polygons.append(_polygons)
|
371 |
+
polygons_labels.append(label)
|
372 |
+
else:
|
373 |
+
_bbox = result['bbox']
|
374 |
+
bboxes.append(_bbox)
|
375 |
+
bboxes_labels.append(label)
|
376 |
+
final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
|
377 |
+
else:
|
378 |
+
raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))
|
379 |
+
|
380 |
+
final_answer = {
|
381 |
+
task: final_answer}
|
382 |
+
return final_answer
|
383 |
+
|
384 |
+
class BoxQuantizer(object):
|
385 |
+
def __init__(self, mode, bins):
|
386 |
+
self.mode = mode
|
387 |
+
self.bins = bins
|
388 |
+
|
389 |
+
def quantize(self, boxes: torch.Tensor, size):
|
390 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
391 |
+
size_w, size_h = size # Original image size.
|
392 |
+
size_per_bin_w = size_w / bins_w
|
393 |
+
size_per_bin_h = size_h / bins_h
|
394 |
+
xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
|
395 |
+
|
396 |
+
if self.mode == 'floor':
|
397 |
+
quantized_xmin = (
|
398 |
+
xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
|
399 |
+
quantized_ymin = (
|
400 |
+
ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
|
401 |
+
quantized_xmax = (
|
402 |
+
xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
|
403 |
+
quantized_ymax = (
|
404 |
+
ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
|
405 |
+
|
406 |
+
elif self.mode == 'round':
|
407 |
+
raise NotImplementedError()
|
408 |
+
|
409 |
+
else:
|
410 |
+
raise ValueError('Incorrect quantization type.')
|
411 |
+
|
412 |
+
quantized_boxes = torch.cat(
|
413 |
+
(quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
|
414 |
+
).int()
|
415 |
+
|
416 |
+
return quantized_boxes
|
417 |
+
|
418 |
+
def dequantize(self, boxes: torch.Tensor, size):
|
419 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
420 |
+
size_w, size_h = size # Original image size.
|
421 |
+
size_per_bin_w = size_w / bins_w
|
422 |
+
size_per_bin_h = size_h / bins_h
|
423 |
+
xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
|
424 |
+
|
425 |
+
if self.mode == 'floor':
|
426 |
+
# Add 0.5 to use the center position of the bin as the coordinate.
|
427 |
+
dequantized_xmin = (xmin + 0.5) * size_per_bin_w
|
428 |
+
dequantized_ymin = (ymin + 0.5) * size_per_bin_h
|
429 |
+
dequantized_xmax = (xmax + 0.5) * size_per_bin_w
|
430 |
+
dequantized_ymax = (ymax + 0.5) * size_per_bin_h
|
431 |
+
|
432 |
+
elif self.mode == 'round':
|
433 |
+
raise NotImplementedError()
|
434 |
+
|
435 |
+
else:
|
436 |
+
raise ValueError('Incorrect quantization type.')
|
437 |
+
|
438 |
+
dequantized_boxes = torch.cat(
|
439 |
+
(dequantized_xmin, dequantized_ymin,
|
440 |
+
dequantized_xmax, dequantized_ymax), dim=-1
|
441 |
+
)
|
442 |
+
|
443 |
+
return dequantized_boxes
|
444 |
+
|
445 |
+
|
446 |
+
class CoordinatesQuantizer(object):
|
447 |
+
"""
|
448 |
+
Quantize coornidates (Nx2)
|
449 |
+
"""
|
450 |
+
|
451 |
+
def __init__(self, mode, bins):
|
452 |
+
self.mode = mode
|
453 |
+
self.bins = bins
|
454 |
+
|
455 |
+
def quantize(self, coordinates: torch.Tensor, size):
|
456 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
457 |
+
size_w, size_h = size # Original image size.
|
458 |
+
size_per_bin_w = size_w / bins_w
|
459 |
+
size_per_bin_h = size_h / bins_h
|
460 |
+
assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
|
461 |
+
x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
|
462 |
+
|
463 |
+
if self.mode == 'floor':
|
464 |
+
quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
|
465 |
+
quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)
|
466 |
+
|
467 |
+
elif self.mode == 'round':
|
468 |
+
raise NotImplementedError()
|
469 |
+
|
470 |
+
else:
|
471 |
+
raise ValueError('Incorrect quantization type.')
|
472 |
+
|
473 |
+
quantized_coordinates = torch.cat(
|
474 |
+
(quantized_x, quantized_y), dim=-1
|
475 |
+
).int()
|
476 |
+
|
477 |
+
return quantized_coordinates
|
478 |
+
|
479 |
+
def dequantize(self, coordinates: torch.Tensor, size):
|
480 |
+
bins_w, bins_h = self.bins # Quantization bins.
|
481 |
+
size_w, size_h = size # Original image size.
|
482 |
+
size_per_bin_w = size_w / bins_w
|
483 |
+
size_per_bin_h = size_h / bins_h
|
484 |
+
assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
|
485 |
+
x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
|
486 |
+
|
487 |
+
if self.mode == 'floor':
|
488 |
+
# Add 0.5 to use the center position of the bin as the coordinate.
|
489 |
+
dequantized_x = (x + 0.5) * size_per_bin_w
|
490 |
+
dequantized_y = (y + 0.5) * size_per_bin_h
|
491 |
+
|
492 |
+
elif self.mode == 'round':
|
493 |
+
raise NotImplementedError()
|
494 |
+
|
495 |
+
else:
|
496 |
+
raise ValueError('Incorrect quantization type.')
|
497 |
+
|
498 |
+
dequantized_coordinates = torch.cat(
|
499 |
+
(dequantized_x, dequantized_y), dim=-1
|
500 |
+
)
|
501 |
+
|
502 |
+
return dequantized_coordinates
|
503 |
+
|
504 |
+
|
505 |
+
class Florence2PostProcesser(object):
|
506 |
+
r"""
|
507 |
+
Florence-2 post process for converting text prediction to various tasks results.
|
508 |
+
|
509 |
+
Args:
|
510 |
+
config: A dict of configs.
|
511 |
+
tokenizer: A tokenizer for decoding text to spans.
|
512 |
+
sample config:
|
513 |
+
UNIFIED_POST_PROCESS:
|
514 |
+
# commom configs
|
515 |
+
NUM_BBOX_HEIGHT_BINS: 1000
|
516 |
+
NUM_BBOX_WIDTH_BINS: 1000
|
517 |
+
COORDINATES_HEIGHT_BINS: 1000
|
518 |
+
COORDINATES_WIDTH_BINS: 1000
|
519 |
+
# task specific configs, override the common configs
|
520 |
+
PRASE_TASKS:
|
521 |
+
- TASK_NAME: 'video_dense_caption'
|
522 |
+
PATTERN: 'r<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
|
523 |
+
SCORE_MODE: 'avg_cat_name_scores'
|
524 |
+
NUM_BINS: 100
|
525 |
+
- TASK_NAME: 'od'
|
526 |
+
PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
|
527 |
+
SCORE_MODE: 'avg_cat_name_scores'
|
528 |
+
|
529 |
+
Returns:
|
530 |
+
parsed_dict (dict): A dict of parsed results.
|
531 |
+
"""
|
532 |
+
def __init__(
|
533 |
+
self,
|
534 |
+
tokenizer=None
|
535 |
+
):
|
536 |
+
parse_tasks = []
|
537 |
+
parse_task_configs = {}
|
538 |
+
config = self._create_default_config()
|
539 |
+
for task in config['PARSE_TASKS']:
|
540 |
+
parse_tasks.append(task['TASK_NAME'])
|
541 |
+
parse_task_configs[task['TASK_NAME']] = task
|
542 |
+
|
543 |
+
self.config = config
|
544 |
+
self.parse_tasks = parse_tasks
|
545 |
+
self.parse_tasks_configs = parse_task_configs
|
546 |
+
|
547 |
+
self.tokenizer = tokenizer
|
548 |
+
if self.tokenizer is not None:
|
549 |
+
self.all_special_tokens = set(self.tokenizer.all_special_tokens)
|
550 |
+
|
551 |
+
self.init_quantizers()
|
552 |
+
self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
|
553 |
+
|
554 |
+
def _create_black_list_of_phrase_grounding(self):
|
555 |
+
black_list = {}
|
556 |
+
|
557 |
+
if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
|
558 |
+
black_list = set(
|
559 |
+
['it', 'I', 'me', 'mine',
|
560 |
+
'you', 'your', 'yours',
|
561 |
+
'he', 'him', 'his',
|
562 |
+
'she', 'her', 'hers',
|
563 |
+
'they', 'them', 'their', 'theirs',
|
564 |
+
'one', 'oneself',
|
565 |
+
'we', 'us', 'our', 'ours',
|
566 |
+
'you', 'your', 'yours',
|
567 |
+
'they', 'them', 'their', 'theirs',
|
568 |
+
'mine', 'yours', 'his', 'hers', 'its',
|
569 |
+
'ours', 'yours', 'theirs',
|
570 |
+
'myself', 'yourself', 'himself', 'herself', 'itself',
|
571 |
+
'ourselves', 'yourselves', 'themselves',
|
572 |
+
'this', 'that',
|
573 |
+
'these', 'those',
|
574 |
+
'who', 'whom', 'whose', 'which', 'what',
|
575 |
+
'who', 'whom', 'whose', 'which', 'that',
|
576 |
+
'all', 'another', 'any', 'anybody', 'anyone', 'anything',
|
577 |
+
'each', 'everybody', 'everyone', 'everything',
|
578 |
+
'few', 'many', 'nobody', 'none', 'one', 'several',
|
579 |
+
'some', 'somebody', 'someone', 'something',
|
580 |
+
'each other', 'one another',
|
581 |
+
'myself', 'yourself', 'himself', 'herself', 'itself',
|
582 |
+
'ourselves', 'yourselves', 'themselves',
|
583 |
+
'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
|
584 |
+
'other objects', 'lots', 'a set',
|
585 |
+
]
|
586 |
+
)
|
587 |
+
|
588 |
+
return black_list
|
589 |
+
|
590 |
+
def _create_default_config(self):
|
591 |
+
config = {
|
592 |
+
'NUM_BBOX_HEIGHT_BINS': 1000,
|
593 |
+
'NUM_BBOX_WIDTH_BINS': 1000,
|
594 |
+
'BOX_QUANTIZATION_MODE': 'floor',
|
595 |
+
'COORDINATES_HEIGHT_BINS': 1000,
|
596 |
+
'COORDINATES_WIDTH_BINS': 1000,
|
597 |
+
'COORDINATES_QUANTIZATION_MODE': 'floor',
|
598 |
+
'PARSE_TASKS': [
|
599 |
+
{
|
600 |
+
'TASK_NAME': 'od',
|
601 |
+
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
|
602 |
+
'SCORE_MODE': 'avg_loc_scores'
|
603 |
+
},
|
604 |
+
{
|
605 |
+
'TASK_NAME': 'ocr',
|
606 |
+
'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>',
|
607 |
+
'AREA_THRESHOLD': 0.00
|
608 |
+
},
|
609 |
+
{
|
610 |
+
'TASK_NAME': 'phrase_grounding',
|
611 |
+
'FILTER_BY_BLACK_LIST': True
|
612 |
+
},
|
613 |
+
{
|
614 |
+
'TASK_NAME': 'pure_text',
|
615 |
+
},
|
616 |
+
{
|
617 |
+
'TASK_NAME': 'description_with_bboxes',
|
618 |
+
'SCORE_MODE': 'avg_loc_scores'
|
619 |
+
},
|
620 |
+
{
|
621 |
+
'TASK_NAME': 'description_with_polygons',
|
622 |
+
},
|
623 |
+
{
|
624 |
+
'TASK_NAME': 'polygons',
|
625 |
+
},
|
626 |
+
{
|
627 |
+
'TASK_NAME': 'bboxes',
|
628 |
+
},
|
629 |
+
{
|
630 |
+
'TASK_NAME': 'description_with_bboxes_or_polygons',
|
631 |
+
}
|
632 |
+
]
|
633 |
+
}
|
634 |
+
|
635 |
+
return config
|
636 |
+
|
637 |
+
def init_quantizers(self):
|
638 |
+
# we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation)
|
639 |
+
num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
|
640 |
+
num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
|
641 |
+
box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
|
642 |
+
self.box_quantizer = BoxQuantizer(
|
643 |
+
box_quantization_mode,
|
644 |
+
(num_bbox_width_bins, num_bbox_height_bins),
|
645 |
+
)
|
646 |
+
|
647 |
+
num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
|
648 |
+
num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
|
649 |
+
box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor')
|
650 |
+
self.coordinates_quantizer = CoordinatesQuantizer(
|
651 |
+
box_quantization_mode,
|
652 |
+
(num_bbox_width_bins, num_bbox_height_bins),
|
653 |
+
)
|
654 |
+
|
655 |
+
def decode_with_spans(self, tokenizer, token_ids):
|
656 |
+
filtered_tokens = tokenizer.convert_ids_to_tokens(
|
657 |
+
token_ids, skip_special_tokens=False)
|
658 |
+
assert len(filtered_tokens) == len(token_ids)
|
659 |
+
sub_texts = []
|
660 |
+
for token in filtered_tokens:
|
661 |
+
if token in self.all_special_tokens:
|
662 |
+
sub_texts.append(token)
|
663 |
+
else:
|
664 |
+
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
665 |
+
sub_text = tokenizer.convert_tokens_to_string([token])
|
666 |
+
else:
|
667 |
+
raise ValueError(f'type {type(tokenizer)} not supported')
|
668 |
+
sub_texts.append(sub_text)
|
669 |
+
|
670 |
+
text = ''
|
671 |
+
spans = []
|
672 |
+
for sub_text in sub_texts:
|
673 |
+
span = (len(text), len(text) + len(sub_text)) # [start index, end index).
|
674 |
+
text += sub_text
|
675 |
+
spans.append(span)
|
676 |
+
return text, spans
|
677 |
+
|
678 |
+
def parse_od_from_text_and_spans(
|
679 |
+
self,
|
680 |
+
text,
|
681 |
+
pattern,
|
682 |
+
image_size,
|
683 |
+
phrase_centric=False
|
684 |
+
):
|
685 |
+
parsed = list(re.finditer(pattern, text))
|
686 |
+
|
687 |
+
instances = []
|
688 |
+
for i in range(len(parsed)):
|
689 |
+
# Prepare instance.
|
690 |
+
instance = {}
|
691 |
+
|
692 |
+
if phrase_centric:
|
693 |
+
bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)]
|
694 |
+
else:
|
695 |
+
bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)]
|
696 |
+
instance['bbox'] = self.box_quantizer.dequantize(
|
697 |
+
boxes=torch.tensor(bbox_bins),
|
698 |
+
size=image_size
|
699 |
+
).tolist()
|
700 |
+
|
701 |
+
if phrase_centric:
|
702 |
+
instance['cat_name'] = parsed[i].group(1).lower().strip()
|
703 |
+
else:
|
704 |
+
instance['cat_name'] = parsed[i].group(5).lower().strip()
|
705 |
+
instances.append(instance)
|
706 |
+
|
707 |
+
return instances
|
708 |
+
|
709 |
+
def parse_ocr_from_text_and_spans(self,
|
710 |
+
text,
|
711 |
+
pattern,
|
712 |
+
image_size,
|
713 |
+
area_threshold=-1.0,
|
714 |
+
):
|
715 |
+
bboxes = []
|
716 |
+
labels = []
|
717 |
+
text = text.replace('<s>', '')
|
718 |
+
# ocr with regions
|
719 |
+
parsed = re.findall(pattern, text)
|
720 |
+
instances = []
|
721 |
+
image_width, image_height = image_size
|
722 |
+
|
723 |
+
for ocr_line in parsed:
|
724 |
+
ocr_content = ocr_line[0]
|
725 |
+
quad_box = ocr_line[1:]
|
726 |
+
quad_box = [int(i) for i in quad_box]
|
727 |
+
quad_box = self.coordinates_quantizer.dequantize(
|
728 |
+
torch.tensor(np.array(quad_box).reshape(-1, 2)),
|
729 |
+
size=image_size
|
730 |
+
).reshape(-1).tolist()
|
731 |
+
|
732 |
+
if area_threshold > 0:
|
733 |
+
x_coords = [i for i in quad_box[0::2]]
|
734 |
+
y_coords = [i for i in quad_box[1::2]]
|
735 |
+
|
736 |
+
# apply the Shoelace formula
|
737 |
+
area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)))
|
738 |
+
|
739 |
+
if area < (image_width * image_height) * area_threshold:
|
740 |
+
continue
|
741 |
+
|
742 |
+
bboxes.append(quad_box)
|
743 |
+
labels.append(ocr_content)
|
744 |
+
instances.append({
|
745 |
+
'quad_box': quad_box,
|
746 |
+
'text': ocr_content,
|
747 |
+
})
|
748 |
+
return instances
|
749 |
+
|
750 |
+
def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
|
751 |
+
# ignore <s> </s> and <pad>
|
752 |
+
cur_span = 0
|
753 |
+
if text.startswith('<s>'):
|
754 |
+
cur_span += 3
|
755 |
+
|
756 |
+
text = text.replace('<s>', '')
|
757 |
+
text = text.replace('</s>', '')
|
758 |
+
text = text.replace('<pad>', '')
|
759 |
+
|
760 |
+
pattern = r"([^<]+(?:<loc_\d+>){4,})"
|
761 |
+
phrases = re.findall(pattern, text)
|
762 |
+
|
763 |
+
# pattern should be text pattern and od pattern
|
764 |
+
pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
|
765 |
+
box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
|
766 |
+
|
767 |
+
instances = []
|
768 |
+
for pharse_text in phrases:
|
769 |
+
phrase_text_strip = pharse_text.replace('<ground>', '', 1)
|
770 |
+
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
771 |
+
|
772 |
+
if phrase_text_strip == '':
|
773 |
+
cur_span += len(pharse_text)
|
774 |
+
continue
|
775 |
+
|
776 |
+
# Prepare instance.
|
777 |
+
instance = {}
|
778 |
+
|
779 |
+
# parse phrase, get string
|
780 |
+
phrase = re.search(pattern, phrase_text_strip)
|
781 |
+
if phrase is None:
|
782 |
+
cur_span += len(pharse_text)
|
783 |
+
continue
|
784 |
+
|
785 |
+
# parse bboxes by box_pattern
|
786 |
+
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
787 |
+
if len(bboxes_parsed) == 0:
|
788 |
+
cur_span += len(pharse_text)
|
789 |
+
continue
|
790 |
+
|
791 |
+
phrase = phrase.group()
|
792 |
+
# remove leading and trailing spaces
|
793 |
+
phrase = phrase.strip()
|
794 |
+
|
795 |
+
if phrase in self.black_list_of_phrase_grounding:
|
796 |
+
cur_span += len(pharse_text)
|
797 |
+
continue
|
798 |
+
|
799 |
+
# a list of list
|
800 |
+
bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
|
801 |
+
instance['bbox'] = self.box_quantizer.dequantize(
|
802 |
+
boxes=torch.tensor(bbox_bins),
|
803 |
+
size=image_size
|
804 |
+
).tolist()
|
805 |
+
|
806 |
+
# exclude non-ascii characters
|
807 |
+
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
808 |
+
instance['cat_name'] = phrase
|
809 |
+
|
810 |
+
instances.append(instance)
|
811 |
+
|
812 |
+
return instances
|
813 |
+
|
814 |
+
def parse_description_with_bboxes_from_text_and_spans(
|
815 |
+
self,
|
816 |
+
text,
|
817 |
+
spans=None,
|
818 |
+
scores=None,
|
819 |
+
score_mode=None,
|
820 |
+
pattern=None,
|
821 |
+
image_size=None,
|
822 |
+
allow_empty_phrase=False
|
823 |
+
):
|
824 |
+
def find_matched_token_indices(cur_span, token_spans):
|
825 |
+
inds = []
|
826 |
+
for i, token_span in enumerate(token_spans):
|
827 |
+
if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
|
828 |
+
inds.append(i)
|
829 |
+
return inds
|
830 |
+
|
831 |
+
cur_span = 0
|
832 |
+
if text.startswith('<s>'):
|
833 |
+
cur_span += 3
|
834 |
+
|
835 |
+
text = text.replace('<s>', '')
|
836 |
+
text = text.replace('</s>', '')
|
837 |
+
text = text.replace('<pad>', '')
|
838 |
+
|
839 |
+
if allow_empty_phrase:
|
840 |
+
pattern = rf"(?:(?:<loc_\d+>){{4,}})"
|
841 |
+
else:
|
842 |
+
pattern = r"([^<]+(?:<loc_\d+>){4,})"
|
843 |
+
phrases = re.findall(pattern, text)
|
844 |
+
|
845 |
+
# pattern should be text pattern and od pattern
|
846 |
+
pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
|
847 |
+
box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
|
848 |
+
|
849 |
+
instances = []
|
850 |
+
for pharse_text in phrases:
|
851 |
+
phrase_text_strip = pharse_text.replace('<ground>', '', 1)
|
852 |
+
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
853 |
+
|
854 |
+
if phrase_text_strip == '' and not allow_empty_phrase:
|
855 |
+
cur_span += len(pharse_text)
|
856 |
+
continue
|
857 |
+
|
858 |
+
# parse phrase, get string
|
859 |
+
phrase = re.search(pattern, phrase_text_strip)
|
860 |
+
if phrase is None:
|
861 |
+
cur_span += len(pharse_text)
|
862 |
+
continue
|
863 |
+
|
864 |
+
phrase_span = phrase.span()
|
865 |
+
phrase = phrase.group()
|
866 |
+
# remove leading and trailing spaces
|
867 |
+
phrase = phrase.strip()
|
868 |
+
|
869 |
+
# parse bboxes by box_pattern
|
870 |
+
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
871 |
+
if len(bboxes_parsed) == 0:
|
872 |
+
cur_span += len(pharse_text)
|
873 |
+
continue
|
874 |
+
|
875 |
+
# a list of list
|
876 |
+
bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
|
877 |
+
|
878 |
+
bboxes = self.box_quantizer.dequantize(
|
879 |
+
boxes=torch.tensor(bbox_bins),
|
880 |
+
size=image_size
|
881 |
+
).tolist()
|
882 |
+
|
883 |
+
if score_mode == 'avg_loc_scores':
|
884 |
+
if spans is None or scores is None:
|
885 |
+
all_scores = None
|
886 |
+
else:
|
887 |
+
bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
|
888 |
+
all_scores = []
|
889 |
+
for _spans in bbox_end_spans:
|
890 |
+
token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
|
891 |
+
loc_scores = [scores[token_i] for token_i in token_inds]
|
892 |
+
score = sum(loc_scores) / len(loc_scores)
|
893 |
+
all_scores.append(score)
|
894 |
+
elif score_mode == 'avg_cat_name_scores':
|
895 |
+
if spans is None or scores is None:
|
896 |
+
all_scores = None
|
897 |
+
else:
|
898 |
+
cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
|
899 |
+
cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
|
900 |
+
score = sum(cat_name_scores) / len(cat_name_scores)
|
901 |
+
all_scores = [score] * len(bboxes)
|
902 |
+
elif score_mode is None:
|
903 |
+
all_scores = None
|
904 |
+
else:
|
905 |
+
raise ValueError('Unknown score mode: {}'.format(score_mode))
|
906 |
+
|
907 |
+
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
908 |
+
for _idx, _bboxes in enumerate(bboxes):
|
909 |
+
# Prepare instance.
|
910 |
+
instance = {}
|
911 |
+
instance['bbox'] = _bboxes
|
912 |
+
# exclude non-ascii characters
|
913 |
+
instance['cat_name'] = phrase
|
914 |
+
if all_scores is not None:
|
915 |
+
instance['score'] = math.exp(all_scores[_idx])
|
916 |
+
instances.append(instance)
|
917 |
+
|
918 |
+
cur_span += len(pharse_text)
|
919 |
+
|
920 |
+
return instances
|
921 |
+
|
922 |
+
def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
|
923 |
+
allow_empty_phrase=False,
|
924 |
+
polygon_sep_token='<sep>',
|
925 |
+
polygon_start_token='<poly>',
|
926 |
+
polygon_end_token='</poly>',
|
927 |
+
with_box_at_start=False,
|
928 |
+
):
|
929 |
+
|
930 |
+
# ref_seg format: '<expression><x1><y1><x2><y2><><><sep><><><><>'
|
931 |
+
# ignore <s> </s> and <pad>
|
932 |
+
|
933 |
+
text = text.replace('<s>', '')
|
934 |
+
text = text.replace('</s>', '')
|
935 |
+
text = text.replace('<pad>', '')
|
936 |
+
|
937 |
+
if allow_empty_phrase:
|
938 |
+
pattern = rf"(?:(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
|
939 |
+
else:
|
940 |
+
# [^<]+: This part matches one or more characters that are not the < symbol.
|
941 |
+
# The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
|
942 |
+
#
|
943 |
+
pattern = rf"([^<]+(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
|
944 |
+
phrases = re.findall(pattern, text)
|
945 |
+
|
946 |
+
phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
|
947 |
+
box_pattern = rf'((?:<loc_\d+>)+)(?:{re.escape(polygon_sep_token)}|$)'
|
948 |
+
|
949 |
+
# one polygons instance is separated by polygon_start_token and polygon_end_token
|
950 |
+
polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'
|
951 |
+
|
952 |
+
instances = []
|
953 |
+
for phrase_text in phrases:
|
954 |
+
|
955 |
+
# exclude loc_\d+>
|
956 |
+
# need to get span if want to include category score
|
957 |
+
phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
|
958 |
+
|
959 |
+
# phrase = phrase.replace('<poly>', '')
|
960 |
+
# phrase = phrase.replace('poly>', '')
|
961 |
+
|
962 |
+
if phrase_text_strip == '' and not allow_empty_phrase:
|
963 |
+
continue
|
964 |
+
|
965 |
+
|
966 |
+
# parse phrase, get string
|
967 |
+
phrase = re.search(phrase_string_pattern, phrase_text_strip)
|
968 |
+
if phrase is None:
|
969 |
+
continue
|
970 |
+
phrase = phrase.group()
|
971 |
+
# remove leading and trailing spaces
|
972 |
+
phrase = phrase.strip()
|
973 |
+
|
974 |
+
# parse bboxes by box_pattern
|
975 |
+
|
976 |
+
# split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
|
977 |
+
if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
|
978 |
+
polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
|
979 |
+
else:
|
980 |
+
polygons_instances_parsed = [phrase_text]
|
981 |
+
|
982 |
+
for _polygons_instances_parsed in polygons_instances_parsed:
|
983 |
+
# Prepare instance.
|
984 |
+
instance = {}
|
985 |
+
|
986 |
+
# polygons_parsed= list(re.finditer(box_pattern, phrase_text))
|
987 |
+
if isinstance(_polygons_instances_parsed, str):
|
988 |
+
polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
|
989 |
+
else:
|
990 |
+
polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
|
991 |
+
if len(polygons_parsed) == 0:
|
992 |
+
continue
|
993 |
+
|
994 |
+
# a list of list (polygon)
|
995 |
+
bbox = []
|
996 |
+
polygons = []
|
997 |
+
for _polygon_parsed in polygons_parsed:
|
998 |
+
# group 1: whole <loc_\d+>...</loc_\d+>
|
999 |
+
_polygon = _polygon_parsed.group(1)
|
1000 |
+
# parse into list of int
|
1001 |
+
_polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'<loc_(\d+)>', _polygon)]
|
1002 |
+
if with_box_at_start and len(bbox) == 0:
|
1003 |
+
if len(_polygon) > 4:
|
1004 |
+
# no valid bbox prediction
|
1005 |
+
bbox = _polygon[:4]
|
1006 |
+
_polygon = _polygon[4:]
|
1007 |
+
else:
|
1008 |
+
bbox = [0, 0, 0, 0]
|
1009 |
+
# abandon last element if is not paired
|
1010 |
+
if len(_polygon) % 2 == 1:
|
1011 |
+
_polygon = _polygon[:-1]
|
1012 |
+
|
1013 |
+
# reshape into (n, 2)
|
1014 |
+
_polygon = self.coordinates_quantizer.dequantize(
|
1015 |
+
torch.tensor(np.array(_polygon).reshape(-1, 2)),
|
1016 |
+
size=image_size
|
1017 |
+
).reshape(-1).tolist()
|
1018 |
+
# reshape back
|
1019 |
+
polygons.append(_polygon)
|
1020 |
+
|
1021 |
+
instance['cat_name'] = phrase
|
1022 |
+
instance['polygons'] = polygons
|
1023 |
+
if len(bbox) != 0:
|
1024 |
+
instance['bbox'] = self.box_quantizer.dequantize(
|
1025 |
+
boxes=torch.tensor([bbox]),
|
1026 |
+
size=image_size
|
1027 |
+
).tolist()[0]
|
1028 |
+
|
1029 |
+
instances.append(instance)
|
1030 |
+
|
1031 |
+
return instances
|
1032 |
+
|
1033 |
+
def __call__(
|
1034 |
+
self,
|
1035 |
+
text=None,
|
1036 |
+
sequence=None,
|
1037 |
+
transition_beam_score=None,
|
1038 |
+
image_size=None,
|
1039 |
+
parse_tasks=None,
|
1040 |
+
):
|
1041 |
+
"""
|
1042 |
+
Args:
|
1043 |
+
text: model outputs
|
1044 |
+
image_size: (width, height)
|
1045 |
+
parse_tasks: a list of tasks to parse, if None, parse all tasks.
|
1046 |
+
|
1047 |
+
"""
|
1048 |
+
if parse_tasks is not None:
|
1049 |
+
if isinstance(parse_tasks, str):
|
1050 |
+
parse_tasks = [parse_tasks]
|
1051 |
+
for _parse_task in parse_tasks:
|
1052 |
+
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
1053 |
+
|
1054 |
+
# sequence or text should be provided
|
1055 |
+
assert sequence is not None or text is not None, 'sequence or text should be provided'
|
1056 |
+
assert sequence is None or text is None, 'only one of sequence and text should be provided'
|
1057 |
+
|
1058 |
+
if sequence is not None:
|
1059 |
+
sequence = sequence.tolist()[1:]
|
1060 |
+
text, spans = self.decode_with_spans(self.tokenizer, sequence)
|
1061 |
+
if transition_beam_score is not None:
|
1062 |
+
transition_beam_score = transition_beam_score.tolist()
|
1063 |
+
assert len(sequence) == len(transition_beam_score)
|
1064 |
+
else:
|
1065 |
+
spans = None
|
1066 |
+
transition_beam_score = None
|
1067 |
+
|
1068 |
+
parsed_dict = {
|
1069 |
+
'text': text
|
1070 |
+
}
|
1071 |
+
|
1072 |
+
for task in self.parse_tasks:
|
1073 |
+
if parse_tasks is not None and task not in parse_tasks:
|
1074 |
+
continue
|
1075 |
+
|
1076 |
+
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
1077 |
+
score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
|
1078 |
+
|
1079 |
+
if task == 'ocr':
|
1080 |
+
instances = self.parse_ocr_from_text_and_spans(
|
1081 |
+
text,
|
1082 |
+
pattern=pattern,
|
1083 |
+
image_size=image_size,
|
1084 |
+
area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0),
|
1085 |
+
)
|
1086 |
+
parsed_dict['ocr'] = instances
|
1087 |
+
elif task == 'phrase_grounding':
|
1088 |
+
instances = self.parse_phrase_grounding_from_text_and_spans(
|
1089 |
+
text,
|
1090 |
+
pattern=pattern,
|
1091 |
+
image_size=image_size,
|
1092 |
+
)
|
1093 |
+
parsed_dict['phrase_grounding'] = instances
|
1094 |
+
elif task == 'pure_text':
|
1095 |
+
parsed_dict['pure_text'] = text
|
1096 |
+
elif task == 'description_with_bboxes':
|
1097 |
+
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1098 |
+
text,
|
1099 |
+
spans=spans,
|
1100 |
+
scores=transition_beam_score,
|
1101 |
+
score_mode=score_mode,
|
1102 |
+
pattern=pattern,
|
1103 |
+
image_size=image_size,
|
1104 |
+
)
|
1105 |
+
parsed_dict['description_with_bboxes'] = instances
|
1106 |
+
elif task == 'description_with_polygons':
|
1107 |
+
instances = self.parse_description_with_polygons_from_text_and_spans(
|
1108 |
+
text,
|
1109 |
+
pattern=pattern,
|
1110 |
+
image_size=image_size,
|
1111 |
+
)
|
1112 |
+
parsed_dict['description_with_polygons'] = instances
|
1113 |
+
elif task == 'polygons':
|
1114 |
+
instances = self.parse_description_with_polygons_from_text_and_spans(
|
1115 |
+
text,
|
1116 |
+
pattern=pattern,
|
1117 |
+
image_size=image_size,
|
1118 |
+
allow_empty_phrase=True,
|
1119 |
+
)
|
1120 |
+
parsed_dict['polygons'] = instances
|
1121 |
+
elif task == 'bboxes':
|
1122 |
+
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1123 |
+
text,
|
1124 |
+
pattern=pattern,
|
1125 |
+
image_size=image_size,
|
1126 |
+
allow_empty_phrase=True,
|
1127 |
+
)
|
1128 |
+
parsed_dict['bboxes'] = instances
|
1129 |
+
elif task == 'description_with_bboxes_or_polygons':
|
1130 |
+
if '<poly>' in text:
|
1131 |
+
# only support either polygons or bboxes, not both at the same time
|
1132 |
+
instances = self.parse_description_with_polygons_from_text_and_spans(
|
1133 |
+
text,
|
1134 |
+
pattern=pattern,
|
1135 |
+
image_size=image_size,
|
1136 |
+
)
|
1137 |
+
else:
|
1138 |
+
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
1139 |
+
text,
|
1140 |
+
pattern=pattern,
|
1141 |
+
image_size=image_size,
|
1142 |
+
)
|
1143 |
+
parsed_dict['description_with_bboxes_or_polygons'] = instances
|
1144 |
+
else:
|
1145 |
+
raise ValueError("task {} is not supported".format(task))
|
1146 |
+
|
1147 |
+
return parsed_dict
|
eval/grounded_sam/florence2/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
eval/grounded_sam/florence2/tokenizer_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_max_length": 1024
|
3 |
+
}
|
4 |
+
|
eval/grounded_sam/florence2/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import supervision as sv
|
7 |
+
from PIL import Image
|
8 |
+
import gc
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from eval.grounded_sam.florence2.modeling_florence2 import Florence2ForConditionalGeneration
|
12 |
+
from eval.grounded_sam.florence2.processing_florence2 import Florence2Processor
|
13 |
+
from eval.grounded_sam.sam2.build_sam import build_sam2
|
14 |
+
from eval.grounded_sam.sam2.sam2_image_predictor import SAM2ImagePredictor
|
15 |
+
|
16 |
+
|
17 |
+
class FlorenceSAM:
|
18 |
+
|
19 |
+
# official usage: https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb
|
20 |
+
TASK_PROMPT = {
|
21 |
+
"original": "<GIVEN>",
|
22 |
+
"caption": "<CAPTION>",
|
23 |
+
"detailed_caption": "<DETAILED_CAPTION>",
|
24 |
+
"more_detailed_caption": "<MORE_DETAILED_CAPTION>",
|
25 |
+
"object_detection": "<OD>",
|
26 |
+
"dense_region_caption": "<DENSE_REGION_CAPTION>",
|
27 |
+
"region_proposal": "<REGION_PROPOSAL>",
|
28 |
+
"phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
|
29 |
+
"referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
|
30 |
+
"region_to_segmentation": "<REGION_TO_SEGMENTATION>",
|
31 |
+
"open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
|
32 |
+
"region_to_category": "<REGION_TO_CATEGORY>",
|
33 |
+
"region_to_description": "<REGION_TO_DESCRIPTION>",
|
34 |
+
"ocr": "<OCR>",
|
35 |
+
"ocr_with_region": "<OCR_WITH_REGION>",
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
def __init__(self, device):
|
40 |
+
"""
|
41 |
+
Init Florence-2 and SAM 2 Model
|
42 |
+
"""
|
43 |
+
print(f"[{self}] init on device {device}")
|
44 |
+
self.device = torch.device(device)
|
45 |
+
|
46 |
+
# with torch.autocast(device_type="cuda", dtype=torch.float32).__enter__()
|
47 |
+
# self.torch_dtype = torch.float32
|
48 |
+
# self.torch_dtype = torch.float16
|
49 |
+
self.torch_dtype = torch.bfloat16
|
50 |
+
|
51 |
+
try:
|
52 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
53 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
54 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
55 |
+
torch.backends.cudnn.allow_tf32 = True
|
56 |
+
# self.torch_dtype = torch.bfloat16
|
57 |
+
# else:
|
58 |
+
# self.torch_dtype = torch.float16
|
59 |
+
except:
|
60 |
+
self.torch_dtype = torch.bfloat16
|
61 |
+
|
62 |
+
FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
|
63 |
+
SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH')
|
64 |
+
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
65 |
+
|
66 |
+
self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
|
67 |
+
FLORENCE2_MODEL_ID,
|
68 |
+
torch_dtype=self.torch_dtype,
|
69 |
+
).eval().to(self.device)
|
70 |
+
self.florence2_processor = Florence2Processor.from_pretrained(
|
71 |
+
FLORENCE2_MODEL_ID,
|
72 |
+
)
|
73 |
+
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=self.device)
|
74 |
+
self.sam2_predictor = SAM2ImagePredictor(sam2_model)
|
75 |
+
|
76 |
+
def __str__(self):
|
77 |
+
return "FlorenceSAM"
|
78 |
+
|
79 |
+
|
80 |
+
@torch.no_grad()
|
81 |
+
def run_florence2(self, task_prompt, text_input, image):
|
82 |
+
model = self.florence2_model
|
83 |
+
processor = self.florence2_processor
|
84 |
+
device = self.device
|
85 |
+
assert model is not None, "You should pass the init florence-2 model here"
|
86 |
+
assert processor is not None, "You should set florence-2 processor here"
|
87 |
+
|
88 |
+
with torch.autocast(device_type="cuda", dtype=torch.float32):
|
89 |
+
if text_input is None:
|
90 |
+
prompt = task_prompt
|
91 |
+
else:
|
92 |
+
prompt = task_prompt + text_input
|
93 |
+
|
94 |
+
inputs = processor(
|
95 |
+
text=prompt, images=image,
|
96 |
+
max_length=1024,
|
97 |
+
truncation=True,
|
98 |
+
return_tensors="pt",
|
99 |
+
).to(device, self.torch_dtype)
|
100 |
+
# inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, self.torch_dtype)
|
101 |
+
generated_ids = model.generate(
|
102 |
+
input_ids=inputs["input_ids"].to(device),
|
103 |
+
pixel_values=inputs["pixel_values"].to(device),
|
104 |
+
# max_new_tokens=1024,
|
105 |
+
max_new_tokens=768,
|
106 |
+
early_stopping=False,
|
107 |
+
do_sample=False,
|
108 |
+
num_beams=3,
|
109 |
+
)
|
110 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
111 |
+
parsed_answer = processor.post_process_generation(
|
112 |
+
generated_text,
|
113 |
+
task=task_prompt,
|
114 |
+
image_size=(image.width, image.height)
|
115 |
+
)
|
116 |
+
return parsed_answer
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
def caption(self, image, caption_task_prompt='<CAPTION>'):
|
121 |
+
assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
|
122 |
+
caption_results = self.run_florence2(caption_task_prompt, None, image)
|
123 |
+
text_input = caption_results[caption_task_prompt]
|
124 |
+
caption = text_input
|
125 |
+
return caption
|
126 |
+
|
127 |
+
|
128 |
+
def segmentation(self, image, input_boxes, seg_model="sam"):
|
129 |
+
if seg_model == "sam":
|
130 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float32):
|
131 |
+
sam2_predictor = self.sam2_predictor
|
132 |
+
sam2_predictor.set_image(np.array(image))
|
133 |
+
masks, scores, logits = sam2_predictor.predict(
|
134 |
+
point_coords=None,
|
135 |
+
point_labels=None,
|
136 |
+
box=input_boxes,
|
137 |
+
multimask_output=False,
|
138 |
+
)
|
139 |
+
if masks.ndim == 4:
|
140 |
+
masks = masks.squeeze(1)
|
141 |
+
if scores.ndim == 2:
|
142 |
+
scores = scores.squeeze(1)
|
143 |
+
else:
|
144 |
+
raise NotImplementedError()
|
145 |
+
|
146 |
+
return masks, scores
|
147 |
+
|
148 |
+
def post_process_results(self, image, caption, labels, detections, output_dir=None):
|
149 |
+
result_dict = {
|
150 |
+
"caption": caption,
|
151 |
+
"instance_images": [],
|
152 |
+
"instance_labels": [],
|
153 |
+
"instance_bboxes": [],
|
154 |
+
"instance_mask_scores": [],
|
155 |
+
}
|
156 |
+
|
157 |
+
if detections is None:
|
158 |
+
return detections, result_dict
|
159 |
+
|
160 |
+
if output_dir is not None:
|
161 |
+
os.makedirs(output_dir, exist_ok=True)
|
162 |
+
|
163 |
+
cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
164 |
+
|
165 |
+
box_annotator = sv.BoxAnnotator()
|
166 |
+
annotated_frame = box_annotator.annotate(scene=cv_image.copy(), detections=detections)
|
167 |
+
|
168 |
+
label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
|
169 |
+
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
170 |
+
if output_dir is not None:
|
171 |
+
cv2.imwrite(os.path.join(output_dir, "detections.jpg"), annotated_frame)
|
172 |
+
|
173 |
+
mask_annotator = sv.MaskAnnotator()
|
174 |
+
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
175 |
+
if output_dir is not None:
|
176 |
+
cv2.imwrite(os.path.join(output_dir, "masks.jpg"), annotated_frame)
|
177 |
+
|
178 |
+
for detection in detections:
|
179 |
+
xyxy, mask, confidence, class_id, tracker_id, data = detection
|
180 |
+
|
181 |
+
label = labels[class_id]
|
182 |
+
cropped_img = sv.crop_image(image=cv_image, xyxy=xyxy)
|
183 |
+
if output_dir is not None:
|
184 |
+
cv2.imwrite(os.path.join(output_dir, f"cropped_image_{label}.jpg"), cropped_img)
|
185 |
+
|
186 |
+
if mask is None:
|
187 |
+
result_dict["instance_mask_scores"].append(0)
|
188 |
+
result_dict["instance_images"].append(cropped_img)
|
189 |
+
else:
|
190 |
+
mask = np.repeat(mask[..., np.newaxis], 3, axis=-1)
|
191 |
+
masked_img = np.where(mask, cv_image, 255)
|
192 |
+
cropped_masked_img = sv.crop_image(image=masked_img, xyxy=xyxy)
|
193 |
+
result_dict["instance_mask_scores"].append(confidence.item())
|
194 |
+
result_dict["instance_images"].append(cropped_masked_img)
|
195 |
+
|
196 |
+
result_dict["instance_labels"].append(label)
|
197 |
+
result_dict["instance_bboxes"].append(xyxy)
|
198 |
+
if output_dir is not None:
|
199 |
+
cv2.imwrite(os.path.join(output_dir, f"masked_image_{label}.jpg"), cropped_masked_img)
|
200 |
+
|
201 |
+
torch.cuda.empty_cache()
|
202 |
+
gc.collect()
|
203 |
+
return detections, result_dict
|
204 |
+
|
205 |
+
def caption_phrase_grounding_and_segmentation(
|
206 |
+
self,
|
207 |
+
image,
|
208 |
+
seg_model="sam",
|
209 |
+
caption_task_prompt='<CAPTION>',
|
210 |
+
original_caption=None,
|
211 |
+
output_dir=None
|
212 |
+
):
|
213 |
+
|
214 |
+
assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>", "<GIVEN>", "<OPEN_VOCABULARY_DETECTION>"]
|
215 |
+
assert seg_model in ["sam", "florence2"]
|
216 |
+
|
217 |
+
# image caption
|
218 |
+
if caption_task_prompt in ["<GIVEN>", "<OPEN_VOCABULARY_DETECTION>"]:
|
219 |
+
assert original_caption is not None
|
220 |
+
caption = original_caption
|
221 |
+
else:
|
222 |
+
caption_results = self.run_florence2(caption_task_prompt, None, image)
|
223 |
+
text_input = caption_results[caption_task_prompt]
|
224 |
+
caption = text_input
|
225 |
+
|
226 |
+
# phrase grounding
|
227 |
+
grounding_results = self.run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', caption, image)['<CAPTION_TO_PHRASE_GROUNDING>']
|
228 |
+
input_boxes = np.array(grounding_results["bboxes"])
|
229 |
+
class_names = grounding_results["labels"]
|
230 |
+
class_ids = np.array(list(range(len(class_names))))
|
231 |
+
|
232 |
+
# segmentation
|
233 |
+
masks, scores = self.segmentation(image, input_boxes, seg_model)
|
234 |
+
|
235 |
+
labels = [f"{class_name}" for class_name in class_names]
|
236 |
+
detections = sv.Detections(
|
237 |
+
xyxy=input_boxes,
|
238 |
+
mask=masks.astype(bool),
|
239 |
+
class_id=class_ids,
|
240 |
+
confidence=scores,
|
241 |
+
)
|
242 |
+
|
243 |
+
return self.post_process_results(image, caption, labels, detections, output_dir)
|
244 |
+
|
245 |
+
def od_grounding_and_segmentation(
|
246 |
+
self,
|
247 |
+
image,
|
248 |
+
text_input,
|
249 |
+
seg_model="sam",
|
250 |
+
output_dir=None
|
251 |
+
):
|
252 |
+
assert seg_model in ["sam", "florence2"]
|
253 |
+
|
254 |
+
# od grounding
|
255 |
+
grounding_results = self.run_florence2('<OPEN_VOCABULARY_DETECTION>', text_input, image)['<OPEN_VOCABULARY_DETECTION>']
|
256 |
+
if len(grounding_results["bboxes"]) == 0:
|
257 |
+
detections = None
|
258 |
+
labels = []
|
259 |
+
else:
|
260 |
+
input_boxes = np.array(grounding_results["bboxes"])
|
261 |
+
class_names = grounding_results["bboxes_labels"]
|
262 |
+
class_ids = np.array(list(range(len(class_names))))
|
263 |
+
|
264 |
+
# segmentation
|
265 |
+
masks, scores = self.segmentation(image, input_boxes, seg_model)
|
266 |
+
|
267 |
+
labels = [f"{class_name}" for class_name in class_names]
|
268 |
+
detections = sv.Detections(
|
269 |
+
xyxy=input_boxes,
|
270 |
+
mask=masks.astype(bool),
|
271 |
+
class_id=class_ids,
|
272 |
+
confidence=scores,
|
273 |
+
)
|
274 |
+
|
275 |
+
return self.post_process_results(image, text_input, labels, detections, output_dir)
|
276 |
+
|
277 |
+
def od_grounding(
|
278 |
+
self,
|
279 |
+
image,
|
280 |
+
text_input,
|
281 |
+
output_dir=None
|
282 |
+
):
|
283 |
+
|
284 |
+
# od grounding
|
285 |
+
grounding_results = self.run_florence2('<OPEN_VOCABULARY_DETECTION>', text_input, image)['<OPEN_VOCABULARY_DETECTION>']
|
286 |
+
if len(grounding_results["bboxes"]) == 0:
|
287 |
+
detections = None
|
288 |
+
labels = []
|
289 |
+
else:
|
290 |
+
input_boxes = np.array(grounding_results["bboxes"])
|
291 |
+
class_names = grounding_results["bboxes_labels"]
|
292 |
+
class_ids = np.array(list(range(len(class_names))))
|
293 |
+
|
294 |
+
labels = [f"{class_name}" for class_name in class_names]
|
295 |
+
detections = sv.Detections(
|
296 |
+
xyxy=input_boxes,
|
297 |
+
class_id=class_ids,
|
298 |
+
)
|
299 |
+
|
300 |
+
return self.post_process_results(image, text_input, labels, detections, output_dir)
|
301 |
+
|
302 |
+
def phrase_grounding_and_segmentation(
|
303 |
+
self,
|
304 |
+
image,
|
305 |
+
text_input,
|
306 |
+
seg_model="sam",
|
307 |
+
output_dir=None
|
308 |
+
):
|
309 |
+
assert seg_model in ["sam", "florence2"]
|
310 |
+
|
311 |
+
# phrase grounding
|
312 |
+
grounding_results = self.run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', text_input, image)['<CAPTION_TO_PHRASE_GROUNDING>']
|
313 |
+
input_boxes = np.array(grounding_results["bboxes"])
|
314 |
+
class_names = grounding_results["labels"]
|
315 |
+
# print(f"[phrase_grounding_and_segmentation] input_label={text_input}, output_label={class_names}")
|
316 |
+
class_ids = np.array(list(range(len(class_names))))
|
317 |
+
|
318 |
+
# segmentation
|
319 |
+
masks, scores = self.segmentation(image, input_boxes, seg_model)
|
320 |
+
|
321 |
+
labels = [f"{class_name}" for class_name in class_names]
|
322 |
+
detections = sv.Detections(
|
323 |
+
xyxy=input_boxes,
|
324 |
+
mask=masks.astype(bool),
|
325 |
+
class_id=class_ids,
|
326 |
+
confidence=scores,
|
327 |
+
)
|
328 |
+
|
329 |
+
return self.post_process_results(image, text_input, labels, detections, output_dir)
|
330 |
+
|
331 |
+
|
332 |
+
if __name__ == "__main__":
|
333 |
+
|
334 |
+
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
335 |
+
parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
|
336 |
+
parser.add_argument("--caption_type", type=str, default="caption", required=False, help="granularity of caption")
|
337 |
+
args = parser.parse_args()
|
338 |
+
|
339 |
+
|
340 |
+
|
341 |
+
# IMAGE_PATH = args.image_path
|
342 |
+
PIPELINE = "caption_to_phrase_grounding"
|
343 |
+
CAPTION_TYPE = args.caption_type
|
344 |
+
assert CAPTION_TYPE in ["caption", "detailed_caption", "more_detailed_caption", "original"]
|
345 |
+
|
346 |
+
print(f"Running pipeline: {PIPELINE} now.")
|
347 |
+
|
348 |
+
pipeline = FlorenceSAM("cuda:0")
|
349 |
+
|
350 |
+
from glob import glob
|
351 |
+
from tqdm import tqdm
|
352 |
+
for image_path in tqdm(glob("/mnt/bn/lq-prompt-alignment/personal/chenbowen/code/IPVerse/prompt_alignment/Grounded-SAM-2/notebooks/images/*") * 3):
|
353 |
+
# for image_path in tqdm(glob("/mnt/bn/lq-prompt-alignment/personal/chenbowen/code/IPVerse/prompt_alignment/Grounded-SAM-2/outputs/gcg_pipeline/00001.tar_debug/*.png")):
|
354 |
+
print(pipeline.TASK_PROMPT, CAPTION_TYPE)
|
355 |
+
image = Image.open(image_path).convert("RGB")
|
356 |
+
pipeline.caption_phrase_grounding_and_segmentation(
|
357 |
+
image=image,
|
358 |
+
seg_model="sam",
|
359 |
+
caption_task_prompt=pipeline.TASK_PROMPT[CAPTION_TYPE],
|
360 |
+
output_dir=f"./outputs/{os.path.basename(image_path)}"
|
361 |
+
)
|
eval/grounded_sam/sam2/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from hydra import initialize_config_module
|
8 |
+
from hydra.core.global_hydra import GlobalHydra
|
9 |
+
|
10 |
+
if not GlobalHydra.instance().is_initialized():
|
11 |
+
initialize_config_module("sam2", version_base="1.2")
|
eval/grounded_sam/sam2/automatic_mask_generator.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
|
8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
13 |
+
|
14 |
+
from sam2.modeling.sam2_base import SAM2Base
|
15 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
16 |
+
from sam2.utils.amg import (
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
MaskData,
|
28 |
+
remove_small_regions,
|
29 |
+
rle_to_mask,
|
30 |
+
uncrop_boxes_xyxy,
|
31 |
+
uncrop_masks,
|
32 |
+
uncrop_points,
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class SAM2AutomaticMaskGenerator:
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
model: SAM2Base,
|
40 |
+
points_per_side: Optional[int] = 32,
|
41 |
+
points_per_batch: int = 64,
|
42 |
+
pred_iou_thresh: float = 0.8,
|
43 |
+
stability_score_thresh: float = 0.95,
|
44 |
+
stability_score_offset: float = 1.0,
|
45 |
+
mask_threshold: float = 0.0,
|
46 |
+
box_nms_thresh: float = 0.7,
|
47 |
+
crop_n_layers: int = 0,
|
48 |
+
crop_nms_thresh: float = 0.7,
|
49 |
+
crop_overlap_ratio: float = 512 / 1500,
|
50 |
+
crop_n_points_downscale_factor: int = 1,
|
51 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
52 |
+
min_mask_region_area: int = 0,
|
53 |
+
output_mode: str = "binary_mask",
|
54 |
+
use_m2m: bool = False,
|
55 |
+
multimask_output: bool = True,
|
56 |
+
**kwargs,
|
57 |
+
) -> None:
|
58 |
+
"""
|
59 |
+
Using a SAM 2 model, generates masks for the entire image.
|
60 |
+
Generates a grid of point prompts over the image, then filters
|
61 |
+
low quality and duplicate masks. The default settings are chosen
|
62 |
+
for SAM 2 with a HieraL backbone.
|
63 |
+
|
64 |
+
Arguments:
|
65 |
+
model (Sam): The SAM 2 model to use for mask prediction.
|
66 |
+
points_per_side (int or None): The number of points to be sampled
|
67 |
+
along one side of the image. The total number of points is
|
68 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
69 |
+
point sampling.
|
70 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
71 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
72 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
73 |
+
model's predicted mask quality.
|
74 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
75 |
+
the stability of the mask under changes to the cutoff used to binarize
|
76 |
+
the model's mask predictions.
|
77 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
78 |
+
calculated the stability score.
|
79 |
+
mask_threshold (float): Threshold for binarizing the mask logits
|
80 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
81 |
+
suppression to filter duplicate masks.
|
82 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
83 |
+
crops of the image. Sets the number of layers to run, where each
|
84 |
+
layer has 2**i_layer number of image crops.
|
85 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
86 |
+
suppression to filter duplicate masks between different crops.
|
87 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
88 |
+
In the first crop layer, crops will overlap by this fraction of
|
89 |
+
the image length. Later layers with more crops scale down this overlap.
|
90 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
91 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
92 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
93 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
94 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
95 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
96 |
+
to remove disconnected regions and holes in masks with area smaller
|
97 |
+
than min_mask_region_area. Requires opencv.
|
98 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
99 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
100 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
101 |
+
memory.
|
102 |
+
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
|
103 |
+
multimask_output (bool): Whether to output multimask at each point of the grid.
|
104 |
+
"""
|
105 |
+
|
106 |
+
assert (points_per_side is None) != (
|
107 |
+
point_grids is None
|
108 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
109 |
+
if points_per_side is not None:
|
110 |
+
self.point_grids = build_all_layer_point_grids(
|
111 |
+
points_per_side,
|
112 |
+
crop_n_layers,
|
113 |
+
crop_n_points_downscale_factor,
|
114 |
+
)
|
115 |
+
elif point_grids is not None:
|
116 |
+
self.point_grids = point_grids
|
117 |
+
else:
|
118 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
119 |
+
|
120 |
+
assert output_mode in [
|
121 |
+
"binary_mask",
|
122 |
+
"uncompressed_rle",
|
123 |
+
"coco_rle",
|
124 |
+
], f"Unknown output_mode {output_mode}."
|
125 |
+
if output_mode == "coco_rle":
|
126 |
+
try:
|
127 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
128 |
+
except ImportError as e:
|
129 |
+
print("Please install pycocotools")
|
130 |
+
raise e
|
131 |
+
|
132 |
+
self.predictor = SAM2ImagePredictor(
|
133 |
+
model,
|
134 |
+
max_hole_area=min_mask_region_area,
|
135 |
+
max_sprinkle_area=min_mask_region_area,
|
136 |
+
)
|
137 |
+
self.points_per_batch = points_per_batch
|
138 |
+
self.pred_iou_thresh = pred_iou_thresh
|
139 |
+
self.stability_score_thresh = stability_score_thresh
|
140 |
+
self.stability_score_offset = stability_score_offset
|
141 |
+
self.mask_threshold = mask_threshold
|
142 |
+
self.box_nms_thresh = box_nms_thresh
|
143 |
+
self.crop_n_layers = crop_n_layers
|
144 |
+
self.crop_nms_thresh = crop_nms_thresh
|
145 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
146 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
147 |
+
self.min_mask_region_area = min_mask_region_area
|
148 |
+
self.output_mode = output_mode
|
149 |
+
self.use_m2m = use_m2m
|
150 |
+
self.multimask_output = multimask_output
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
154 |
+
"""
|
155 |
+
Load a pretrained model from the Hugging Face hub.
|
156 |
+
|
157 |
+
Arguments:
|
158 |
+
model_id (str): The Hugging Face repository ID.
|
159 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
(SAM2AutomaticMaskGenerator): The loaded model.
|
163 |
+
"""
|
164 |
+
from sam2.build_sam import build_sam2_hf
|
165 |
+
|
166 |
+
sam_model = build_sam2_hf(model_id, **kwargs)
|
167 |
+
return cls(sam_model, **kwargs)
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
171 |
+
"""
|
172 |
+
Generates masks for the given image.
|
173 |
+
|
174 |
+
Arguments:
|
175 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
179 |
+
a dict containing the following keys:
|
180 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
181 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
182 |
+
is a dictionary containing the RLE.
|
183 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
184 |
+
area (int): The area in pixels of the mask.
|
185 |
+
predicted_iou (float): The model's own prediction of the mask's
|
186 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
187 |
+
point_coords (list(list(float))): The point coordinates input
|
188 |
+
to the model to generate this mask.
|
189 |
+
stability_score (float): A measure of the mask's quality. This
|
190 |
+
is filtered on using the stability_score_thresh parameter.
|
191 |
+
crop_box (list(float)): The crop of the image used to generate
|
192 |
+
the mask, given in XYWH format.
|
193 |
+
"""
|
194 |
+
|
195 |
+
# Generate masks
|
196 |
+
mask_data = self._generate_masks(image)
|
197 |
+
|
198 |
+
# Encode masks
|
199 |
+
if self.output_mode == "coco_rle":
|
200 |
+
mask_data["segmentations"] = [
|
201 |
+
coco_encode_rle(rle) for rle in mask_data["rles"]
|
202 |
+
]
|
203 |
+
elif self.output_mode == "binary_mask":
|
204 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
205 |
+
else:
|
206 |
+
mask_data["segmentations"] = mask_data["rles"]
|
207 |
+
|
208 |
+
# Write mask records
|
209 |
+
curr_anns = []
|
210 |
+
for idx in range(len(mask_data["segmentations"])):
|
211 |
+
ann = {
|
212 |
+
"segmentation": mask_data["segmentations"][idx],
|
213 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
214 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
215 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
216 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
217 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
218 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
219 |
+
}
|
220 |
+
curr_anns.append(ann)
|
221 |
+
|
222 |
+
return curr_anns
|
223 |
+
|
224 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
225 |
+
orig_size = image.shape[:2]
|
226 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
227 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
228 |
+
)
|
229 |
+
|
230 |
+
# Iterate over image crops
|
231 |
+
data = MaskData()
|
232 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
233 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
234 |
+
data.cat(crop_data)
|
235 |
+
|
236 |
+
# Remove duplicate masks between crops
|
237 |
+
if len(crop_boxes) > 1:
|
238 |
+
# Prefer masks from smaller crops
|
239 |
+
scores = 1 / box_area(data["crop_boxes"])
|
240 |
+
scores = scores.to(data["boxes"].device)
|
241 |
+
keep_by_nms = batched_nms(
|
242 |
+
data["boxes"].float(),
|
243 |
+
scores,
|
244 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
245 |
+
iou_threshold=self.crop_nms_thresh,
|
246 |
+
)
|
247 |
+
data.filter(keep_by_nms)
|
248 |
+
data.to_numpy()
|
249 |
+
return data
|
250 |
+
|
251 |
+
def _process_crop(
|
252 |
+
self,
|
253 |
+
image: np.ndarray,
|
254 |
+
crop_box: List[int],
|
255 |
+
crop_layer_idx: int,
|
256 |
+
orig_size: Tuple[int, ...],
|
257 |
+
) -> MaskData:
|
258 |
+
# Crop the image and calculate embeddings
|
259 |
+
x0, y0, x1, y1 = crop_box
|
260 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
261 |
+
cropped_im_size = cropped_im.shape[:2]
|
262 |
+
self.predictor.set_image(cropped_im)
|
263 |
+
|
264 |
+
# Get points for this crop
|
265 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
266 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
267 |
+
|
268 |
+
# Generate masks for this crop in batches
|
269 |
+
data = MaskData()
|
270 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
271 |
+
batch_data = self._process_batch(
|
272 |
+
points, cropped_im_size, crop_box, orig_size, normalize=True
|
273 |
+
)
|
274 |
+
data.cat(batch_data)
|
275 |
+
del batch_data
|
276 |
+
self.predictor.reset_predictor()
|
277 |
+
|
278 |
+
# Remove duplicates within this crop.
|
279 |
+
keep_by_nms = batched_nms(
|
280 |
+
data["boxes"].float(),
|
281 |
+
data["iou_preds"],
|
282 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
283 |
+
iou_threshold=self.box_nms_thresh,
|
284 |
+
)
|
285 |
+
data.filter(keep_by_nms)
|
286 |
+
|
287 |
+
# Return to the original image frame
|
288 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
289 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
290 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
291 |
+
|
292 |
+
return data
|
293 |
+
|
294 |
+
def _process_batch(
|
295 |
+
self,
|
296 |
+
points: np.ndarray,
|
297 |
+
im_size: Tuple[int, ...],
|
298 |
+
crop_box: List[int],
|
299 |
+
orig_size: Tuple[int, ...],
|
300 |
+
normalize=False,
|
301 |
+
) -> MaskData:
|
302 |
+
orig_h, orig_w = orig_size
|
303 |
+
|
304 |
+
# Run model on this batch
|
305 |
+
points = torch.as_tensor(
|
306 |
+
points, dtype=torch.float32, device=self.predictor.device
|
307 |
+
)
|
308 |
+
in_points = self.predictor._transforms.transform_coords(
|
309 |
+
points, normalize=normalize, orig_hw=im_size
|
310 |
+
)
|
311 |
+
in_labels = torch.ones(
|
312 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
313 |
+
)
|
314 |
+
masks, iou_preds, low_res_masks = self.predictor._predict(
|
315 |
+
in_points[:, None, :],
|
316 |
+
in_labels[:, None],
|
317 |
+
multimask_output=self.multimask_output,
|
318 |
+
return_logits=True,
|
319 |
+
)
|
320 |
+
|
321 |
+
# Serialize predictions and store in MaskData
|
322 |
+
data = MaskData(
|
323 |
+
masks=masks.flatten(0, 1),
|
324 |
+
iou_preds=iou_preds.flatten(0, 1),
|
325 |
+
points=points.repeat_interleave(masks.shape[1], dim=0),
|
326 |
+
low_res_masks=low_res_masks.flatten(0, 1),
|
327 |
+
)
|
328 |
+
del masks
|
329 |
+
|
330 |
+
if not self.use_m2m:
|
331 |
+
# Filter by predicted IoU
|
332 |
+
if self.pred_iou_thresh > 0.0:
|
333 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
334 |
+
data.filter(keep_mask)
|
335 |
+
|
336 |
+
# Calculate and filter by stability score
|
337 |
+
data["stability_score"] = calculate_stability_score(
|
338 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
339 |
+
)
|
340 |
+
if self.stability_score_thresh > 0.0:
|
341 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
342 |
+
data.filter(keep_mask)
|
343 |
+
else:
|
344 |
+
# One step refinement using previous mask predictions
|
345 |
+
in_points = self.predictor._transforms.transform_coords(
|
346 |
+
data["points"], normalize=normalize, orig_hw=im_size
|
347 |
+
)
|
348 |
+
labels = torch.ones(
|
349 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
350 |
+
)
|
351 |
+
masks, ious = self.refine_with_m2m(
|
352 |
+
in_points, labels, data["low_res_masks"], self.points_per_batch
|
353 |
+
)
|
354 |
+
data["masks"] = masks.squeeze(1)
|
355 |
+
data["iou_preds"] = ious.squeeze(1)
|
356 |
+
|
357 |
+
if self.pred_iou_thresh > 0.0:
|
358 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
359 |
+
data.filter(keep_mask)
|
360 |
+
|
361 |
+
data["stability_score"] = calculate_stability_score(
|
362 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
363 |
+
)
|
364 |
+
if self.stability_score_thresh > 0.0:
|
365 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
366 |
+
data.filter(keep_mask)
|
367 |
+
|
368 |
+
# Threshold masks and calculate boxes
|
369 |
+
data["masks"] = data["masks"] > self.mask_threshold
|
370 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
371 |
+
|
372 |
+
# Filter boxes that touch crop boundaries
|
373 |
+
keep_mask = ~is_box_near_crop_edge(
|
374 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
375 |
+
)
|
376 |
+
if not torch.all(keep_mask):
|
377 |
+
data.filter(keep_mask)
|
378 |
+
|
379 |
+
# Compress to RLE
|
380 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
381 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
382 |
+
del data["masks"]
|
383 |
+
|
384 |
+
return data
|
385 |
+
|
386 |
+
@staticmethod
|
387 |
+
def postprocess_small_regions(
|
388 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
389 |
+
) -> MaskData:
|
390 |
+
"""
|
391 |
+
Removes small disconnected regions and holes in masks, then reruns
|
392 |
+
box NMS to remove any new duplicates.
|
393 |
+
|
394 |
+
Edits mask_data in place.
|
395 |
+
|
396 |
+
Requires open-cv as a dependency.
|
397 |
+
"""
|
398 |
+
if len(mask_data["rles"]) == 0:
|
399 |
+
return mask_data
|
400 |
+
|
401 |
+
# Filter small disconnected regions and holes
|
402 |
+
new_masks = []
|
403 |
+
scores = []
|
404 |
+
for rle in mask_data["rles"]:
|
405 |
+
mask = rle_to_mask(rle)
|
406 |
+
|
407 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
408 |
+
unchanged = not changed
|
409 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
410 |
+
unchanged = unchanged and not changed
|
411 |
+
|
412 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
413 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
414 |
+
# so NMS will prefer ones that didn't need postprocessing
|
415 |
+
scores.append(float(unchanged))
|
416 |
+
|
417 |
+
# Recalculate boxes and remove any new duplicates
|
418 |
+
masks = torch.cat(new_masks, dim=0)
|
419 |
+
boxes = batched_mask_to_box(masks)
|
420 |
+
keep_by_nms = batched_nms(
|
421 |
+
boxes.float(),
|
422 |
+
torch.as_tensor(scores),
|
423 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
424 |
+
iou_threshold=nms_thresh,
|
425 |
+
)
|
426 |
+
|
427 |
+
# Only recalculate RLEs for masks that have changed
|
428 |
+
for i_mask in keep_by_nms:
|
429 |
+
if scores[i_mask] == 0.0:
|
430 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
431 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
432 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
433 |
+
mask_data.filter(keep_by_nms)
|
434 |
+
|
435 |
+
return mask_data
|
436 |
+
|
437 |
+
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
|
438 |
+
new_masks = []
|
439 |
+
new_iou_preds = []
|
440 |
+
|
441 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
442 |
+
points_per_batch, points, point_labels, low_res_masks
|
443 |
+
):
|
444 |
+
best_masks, best_iou_preds, _ = self.predictor._predict(
|
445 |
+
cur_points[:, None, :],
|
446 |
+
cur_point_labels[:, None],
|
447 |
+
mask_input=low_res_mask[:, None, :],
|
448 |
+
multimask_output=False,
|
449 |
+
return_logits=True,
|
450 |
+
)
|
451 |
+
new_masks.append(best_masks)
|
452 |
+
new_iou_preds.append(best_iou_preds)
|
453 |
+
masks = torch.cat(new_masks, dim=0)
|
454 |
+
return masks, torch.cat(new_iou_preds, dim=0)
|
eval/grounded_sam/sam2/build_sam.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import torch
|
11 |
+
from hydra import compose
|
12 |
+
from hydra.utils import instantiate
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
|
15 |
+
from pathlib import Path
|
16 |
+
current_dir = str(Path(os.path.abspath('')))
|
17 |
+
sam_dir = os.path.join(current_dir, "eval/grounded_sam")
|
18 |
+
sys.path.append(sam_dir)
|
19 |
+
|
20 |
+
import sam2
|
21 |
+
|
22 |
+
# # Check if the user is running Python from the parent directory of the sam2 repo
|
23 |
+
# # (i.e. the directory where this repo is cloned into) -- this is not supported since
|
24 |
+
# # it could shadow the sam2 package and cause issues.
|
25 |
+
# if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
|
26 |
+
# # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
|
27 |
+
# # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
|
28 |
+
# # This typically happens because the user is running Python from the parent directory
|
29 |
+
# # that contains the sam2 repo they cloned.
|
30 |
+
# raise RuntimeError(
|
31 |
+
# "You're likely running Python from the parent directory of the sam2 repository "
|
32 |
+
# "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
|
33 |
+
# "This is not supported since the `sam2` Python package could be shadowed by the "
|
34 |
+
# "repository name (the repository is also named `sam2` and contains the Python package "
|
35 |
+
# "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
|
36 |
+
# "rather than its parent dir, or from your home directory) after installing SAM 2."
|
37 |
+
# )
|
38 |
+
|
39 |
+
|
40 |
+
HF_MODEL_ID_TO_FILENAMES = {
|
41 |
+
"facebook/sam2-hiera-tiny": (
|
42 |
+
"configs/sam2/sam2_hiera_t.yaml",
|
43 |
+
"sam2_hiera_tiny.pt",
|
44 |
+
),
|
45 |
+
"facebook/sam2-hiera-small": (
|
46 |
+
"configs/sam2/sam2_hiera_s.yaml",
|
47 |
+
"sam2_hiera_small.pt",
|
48 |
+
),
|
49 |
+
"facebook/sam2-hiera-base-plus": (
|
50 |
+
"configs/sam2/sam2_hiera_b+.yaml",
|
51 |
+
"sam2_hiera_base_plus.pt",
|
52 |
+
),
|
53 |
+
"facebook/sam2-hiera-large": (
|
54 |
+
"configs/sam2/sam2_hiera_l.yaml",
|
55 |
+
"sam2_hiera_large.pt",
|
56 |
+
),
|
57 |
+
"facebook/sam2.1-hiera-tiny": (
|
58 |
+
"configs/sam2.1/sam2.1_hiera_t.yaml",
|
59 |
+
"sam2.1_hiera_tiny.pt",
|
60 |
+
),
|
61 |
+
"facebook/sam2.1-hiera-small": (
|
62 |
+
"configs/sam2.1/sam2.1_hiera_s.yaml",
|
63 |
+
"sam2.1_hiera_small.pt",
|
64 |
+
),
|
65 |
+
"facebook/sam2.1-hiera-base-plus": (
|
66 |
+
"configs/sam2.1/sam2.1_hiera_b+.yaml",
|
67 |
+
"sam2.1_hiera_base_plus.pt",
|
68 |
+
),
|
69 |
+
"facebook/sam2.1-hiera-large": (
|
70 |
+
"configs/sam2.1/sam2.1_hiera_l.yaml",
|
71 |
+
"sam2.1_hiera_large.pt",
|
72 |
+
),
|
73 |
+
}
|
74 |
+
|
75 |
+
|
76 |
+
def build_sam2(
|
77 |
+
config_file,
|
78 |
+
ckpt_path=None,
|
79 |
+
device="cuda",
|
80 |
+
mode="eval",
|
81 |
+
hydra_overrides_extra=[],
|
82 |
+
apply_postprocessing=True,
|
83 |
+
**kwargs,
|
84 |
+
):
|
85 |
+
|
86 |
+
if apply_postprocessing:
|
87 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
88 |
+
hydra_overrides_extra += [
|
89 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
90 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
91 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
92 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
93 |
+
]
|
94 |
+
# Read config and init model
|
95 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
96 |
+
OmegaConf.resolve(cfg)
|
97 |
+
model = instantiate(cfg.model, _recursive_=True)
|
98 |
+
_load_checkpoint(model, ckpt_path)
|
99 |
+
model = model.to(device)
|
100 |
+
if mode == "eval":
|
101 |
+
model.eval()
|
102 |
+
return model
|
103 |
+
|
104 |
+
|
105 |
+
def build_sam2_video_predictor(
|
106 |
+
config_file,
|
107 |
+
ckpt_path=None,
|
108 |
+
device="cuda",
|
109 |
+
mode="eval",
|
110 |
+
hydra_overrides_extra=[],
|
111 |
+
apply_postprocessing=True,
|
112 |
+
**kwargs,
|
113 |
+
):
|
114 |
+
hydra_overrides = [
|
115 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
116 |
+
]
|
117 |
+
if apply_postprocessing:
|
118 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
119 |
+
hydra_overrides_extra += [
|
120 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
121 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
122 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
123 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
124 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
125 |
+
"++model.binarize_mask_from_pts_for_mem_enc=true",
|
126 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
127 |
+
"++model.fill_hole_area=8",
|
128 |
+
]
|
129 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
130 |
+
|
131 |
+
# Read config and init model
|
132 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides)
|
133 |
+
OmegaConf.resolve(cfg)
|
134 |
+
model = instantiate(cfg.model, _recursive_=True)
|
135 |
+
_load_checkpoint(model, ckpt_path)
|
136 |
+
model = model.to(device)
|
137 |
+
if mode == "eval":
|
138 |
+
model.eval()
|
139 |
+
return model
|
140 |
+
|
141 |
+
|
142 |
+
def _hf_download(model_id):
|
143 |
+
from huggingface_hub import hf_hub_download
|
144 |
+
|
145 |
+
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
|
146 |
+
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
147 |
+
return config_name, ckpt_path
|
148 |
+
|
149 |
+
|
150 |
+
def build_sam2_hf(model_id, **kwargs):
|
151 |
+
config_name, ckpt_path = _hf_download(model_id)
|
152 |
+
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
153 |
+
|
154 |
+
|
155 |
+
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
156 |
+
config_name, ckpt_path = _hf_download(model_id)
|
157 |
+
return build_sam2_video_predictor(
|
158 |
+
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
def _load_checkpoint(model, ckpt_path):
|
163 |
+
if ckpt_path is not None:
|
164 |
+
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
165 |
+
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
166 |
+
if missing_keys:
|
167 |
+
logging.error(missing_keys)
|
168 |
+
raise RuntimeError()
|
169 |
+
if unexpected_keys:
|
170 |
+
logging.error(unexpected_keys)
|
171 |
+
raise RuntimeError()
|
172 |
+
logging.info("Loaded checkpoint sucessfully")
|
eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 112
|
12 |
+
num_heads: 2
|
13 |
+
neck:
|
14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
15 |
+
position_encoding:
|
16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
17 |
+
num_pos_feats: 256
|
18 |
+
normalize: true
|
19 |
+
scale: null
|
20 |
+
temperature: 10000
|
21 |
+
d_model: 256
|
22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
24 |
+
fpn_interp_model: nearest
|
25 |
+
|
26 |
+
memory_attention:
|
27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
28 |
+
d_model: 256
|
29 |
+
pos_enc_at_input: true
|
30 |
+
layer:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
32 |
+
activation: relu
|
33 |
+
dim_feedforward: 2048
|
34 |
+
dropout: 0.1
|
35 |
+
pos_enc_at_attn: false
|
36 |
+
self_attention:
|
37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
38 |
+
rope_theta: 10000.0
|
39 |
+
feat_sizes: [32, 32]
|
40 |
+
embedding_dim: 256
|
41 |
+
num_heads: 1
|
42 |
+
downsample_rate: 1
|
43 |
+
dropout: 0.1
|
44 |
+
d_model: 256
|
45 |
+
pos_enc_at_cross_attn_keys: true
|
46 |
+
pos_enc_at_cross_attn_queries: false
|
47 |
+
cross_attention:
|
48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
49 |
+
rope_theta: 10000.0
|
50 |
+
feat_sizes: [32, 32]
|
51 |
+
rope_k_repeat: True
|
52 |
+
embedding_dim: 256
|
53 |
+
num_heads: 1
|
54 |
+
downsample_rate: 1
|
55 |
+
dropout: 0.1
|
56 |
+
kv_in_dim: 64
|
57 |
+
num_layers: 4
|
58 |
+
|
59 |
+
memory_encoder:
|
60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
61 |
+
out_dim: 64
|
62 |
+
position_encoding:
|
63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
64 |
+
num_pos_feats: 64
|
65 |
+
normalize: true
|
66 |
+
scale: null
|
67 |
+
temperature: 10000
|
68 |
+
mask_downsampler:
|
69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
70 |
+
kernel_size: 3
|
71 |
+
stride: 2
|
72 |
+
padding: 1
|
73 |
+
fuser:
|
74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
75 |
+
layer:
|
76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
77 |
+
dim: 256
|
78 |
+
kernel_size: 7
|
79 |
+
padding: 3
|
80 |
+
layer_scale_init_value: 1e-6
|
81 |
+
use_dwconv: True # depth-wise convs
|
82 |
+
num_layers: 2
|
83 |
+
|
84 |
+
num_maskmem: 7
|
85 |
+
image_size: 1024
|
86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
89 |
+
use_mask_input_as_output_without_sam: true
|
90 |
+
# Memory
|
91 |
+
directly_add_no_mem_embed: true
|
92 |
+
no_obj_embed_spatial: true
|
93 |
+
# use high-resolution feature map in the SAM mask decoder
|
94 |
+
use_high_res_features_in_sam: true
|
95 |
+
# output 3 masks on the first click on initial conditioning frames
|
96 |
+
multimask_output_in_sam: true
|
97 |
+
# SAM heads
|
98 |
+
iou_prediction_use_sigmoid: True
|
99 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
100 |
+
use_obj_ptrs_in_encoder: true
|
101 |
+
add_tpos_enc_to_obj_ptrs: true
|
102 |
+
proj_tpos_enc_in_obj_ptrs: true
|
103 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
105 |
+
# object occlusion prediction
|
106 |
+
pred_obj_scores: true
|
107 |
+
pred_obj_scores_mlp: true
|
108 |
+
fixed_no_obj_ptr: true
|
109 |
+
# multimask tracking settings
|
110 |
+
multimask_output_for_tracking: true
|
111 |
+
use_multimask_token_for_obj_ptr: true
|
112 |
+
multimask_min_pt_num: 0
|
113 |
+
multimask_max_pt_num: 1
|
114 |
+
use_mlp_for_obj_ptr_proj: true
|
115 |
+
# Compilation flag
|
116 |
+
compile_image_encoder: False
|
eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 144
|
12 |
+
num_heads: 2
|
13 |
+
stages: [2, 6, 36, 4]
|
14 |
+
global_att_blocks: [23, 33, 43]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
window_spec: [8, 4, 16, 8]
|
17 |
+
neck:
|
18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
19 |
+
position_encoding:
|
20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
21 |
+
num_pos_feats: 256
|
22 |
+
normalize: true
|
23 |
+
scale: null
|
24 |
+
temperature: 10000
|
25 |
+
d_model: 256
|
26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
28 |
+
fpn_interp_model: nearest
|
29 |
+
|
30 |
+
memory_attention:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
32 |
+
d_model: 256
|
33 |
+
pos_enc_at_input: true
|
34 |
+
layer:
|
35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
36 |
+
activation: relu
|
37 |
+
dim_feedforward: 2048
|
38 |
+
dropout: 0.1
|
39 |
+
pos_enc_at_attn: false
|
40 |
+
self_attention:
|
41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
42 |
+
rope_theta: 10000.0
|
43 |
+
feat_sizes: [32, 32]
|
44 |
+
embedding_dim: 256
|
45 |
+
num_heads: 1
|
46 |
+
downsample_rate: 1
|
47 |
+
dropout: 0.1
|
48 |
+
d_model: 256
|
49 |
+
pos_enc_at_cross_attn_keys: true
|
50 |
+
pos_enc_at_cross_attn_queries: false
|
51 |
+
cross_attention:
|
52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
53 |
+
rope_theta: 10000.0
|
54 |
+
feat_sizes: [32, 32]
|
55 |
+
rope_k_repeat: True
|
56 |
+
embedding_dim: 256
|
57 |
+
num_heads: 1
|
58 |
+
downsample_rate: 1
|
59 |
+
dropout: 0.1
|
60 |
+
kv_in_dim: 64
|
61 |
+
num_layers: 4
|
62 |
+
|
63 |
+
memory_encoder:
|
64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
65 |
+
out_dim: 64
|
66 |
+
position_encoding:
|
67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
68 |
+
num_pos_feats: 64
|
69 |
+
normalize: true
|
70 |
+
scale: null
|
71 |
+
temperature: 10000
|
72 |
+
mask_downsampler:
|
73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
74 |
+
kernel_size: 3
|
75 |
+
stride: 2
|
76 |
+
padding: 1
|
77 |
+
fuser:
|
78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
79 |
+
layer:
|
80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
81 |
+
dim: 256
|
82 |
+
kernel_size: 7
|
83 |
+
padding: 3
|
84 |
+
layer_scale_init_value: 1e-6
|
85 |
+
use_dwconv: True # depth-wise convs
|
86 |
+
num_layers: 2
|
87 |
+
|
88 |
+
num_maskmem: 7
|
89 |
+
image_size: 1024
|
90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
no_obj_embed_spatial: true
|
97 |
+
# use high-resolution feature map in the SAM mask decoder
|
98 |
+
use_high_res_features_in_sam: true
|
99 |
+
# output 3 masks on the first click on initial conditioning frames
|
100 |
+
multimask_output_in_sam: true
|
101 |
+
# SAM heads
|
102 |
+
iou_prediction_use_sigmoid: True
|
103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
104 |
+
use_obj_ptrs_in_encoder: true
|
105 |
+
add_tpos_enc_to_obj_ptrs: true
|
106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
109 |
+
# object occlusion prediction
|
110 |
+
pred_obj_scores: true
|
111 |
+
pred_obj_scores_mlp: true
|
112 |
+
fixed_no_obj_ptr: true
|
113 |
+
# multimask tracking settings
|
114 |
+
multimask_output_for_tracking: true
|
115 |
+
use_multimask_token_for_obj_ptr: true
|
116 |
+
multimask_min_pt_num: 0
|
117 |
+
multimask_max_pt_num: 1
|
118 |
+
use_mlp_for_obj_ptr_proj: true
|
119 |
+
# Compilation flag
|
120 |
+
compile_image_encoder: False
|
eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 11, 2]
|
14 |
+
global_att_blocks: [7, 10, 13]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [32, 32]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [32, 32]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
92 |
+
use_mask_input_as_output_without_sam: true
|
93 |
+
# Memory
|
94 |
+
directly_add_no_mem_embed: true
|
95 |
+
no_obj_embed_spatial: true
|
96 |
+
# use high-resolution feature map in the SAM mask decoder
|
97 |
+
use_high_res_features_in_sam: true
|
98 |
+
# output 3 masks on the first click on initial conditioning frames
|
99 |
+
multimask_output_in_sam: true
|
100 |
+
# SAM heads
|
101 |
+
iou_prediction_use_sigmoid: True
|
102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
103 |
+
use_obj_ptrs_in_encoder: true
|
104 |
+
add_tpos_enc_to_obj_ptrs: true
|
105 |
+
proj_tpos_enc_in_obj_ptrs: true
|
106 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
107 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
108 |
+
# object occlusion prediction
|
109 |
+
pred_obj_scores: true
|
110 |
+
pred_obj_scores_mlp: true
|
111 |
+
fixed_no_obj_ptr: true
|
112 |
+
# multimask tracking settings
|
113 |
+
multimask_output_for_tracking: true
|
114 |
+
use_multimask_token_for_obj_ptr: true
|
115 |
+
multimask_min_pt_num: 0
|
116 |
+
multimask_max_pt_num: 1
|
117 |
+
use_mlp_for_obj_ptr_proj: true
|
118 |
+
# Compilation flag
|
119 |
+
compile_image_encoder: False
|
eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 7, 2]
|
14 |
+
global_att_blocks: [5, 7, 9]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [32, 32]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [32, 32]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
# SAM decoder
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
no_obj_embed_spatial: true
|
97 |
+
# use high-resolution feature map in the SAM mask decoder
|
98 |
+
use_high_res_features_in_sam: true
|
99 |
+
# output 3 masks on the first click on initial conditioning frames
|
100 |
+
multimask_output_in_sam: true
|
101 |
+
# SAM heads
|
102 |
+
iou_prediction_use_sigmoid: True
|
103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
104 |
+
use_obj_ptrs_in_encoder: true
|
105 |
+
add_tpos_enc_to_obj_ptrs: true
|
106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
109 |
+
# object occlusion prediction
|
110 |
+
pred_obj_scores: true
|
111 |
+
pred_obj_scores_mlp: true
|
112 |
+
fixed_no_obj_ptr: true
|
113 |
+
# multimask tracking settings
|
114 |
+
multimask_output_for_tracking: true
|
115 |
+
use_multimask_token_for_obj_ptr: true
|
116 |
+
multimask_min_pt_num: 0
|
117 |
+
multimask_max_pt_num: 1
|
118 |
+
use_mlp_for_obj_ptr_proj: true
|
119 |
+
# Compilation flag
|
120 |
+
# HieraT does not currently support compilation, should always be set to False
|
121 |
+
compile_image_encoder: False
|
eval/grounded_sam/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
scratch:
|
4 |
+
resolution: 1024
|
5 |
+
train_batch_size: 1
|
6 |
+
num_train_workers: 10
|
7 |
+
num_frames: 8
|
8 |
+
max_num_objects: 3
|
9 |
+
base_lr: 5.0e-6
|
10 |
+
vision_lr: 3.0e-06
|
11 |
+
phases_per_epoch: 1
|
12 |
+
num_epochs: 40
|
13 |
+
|
14 |
+
dataset:
|
15 |
+
# PATHS to Dataset
|
16 |
+
img_folder: null # PATH to MOSE JPEGImages folder
|
17 |
+
gt_folder: null # PATH to MOSE Annotations folder
|
18 |
+
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
|
19 |
+
multiplier: 2
|
20 |
+
|
21 |
+
# Video transforms
|
22 |
+
vos:
|
23 |
+
train_transforms:
|
24 |
+
- _target_: training.dataset.transforms.ComposeAPI
|
25 |
+
transforms:
|
26 |
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
27 |
+
consistent_transform: True
|
28 |
+
- _target_: training.dataset.transforms.RandomAffine
|
29 |
+
degrees: 25
|
30 |
+
shear: 20
|
31 |
+
image_interpolation: bilinear
|
32 |
+
consistent_transform: True
|
33 |
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
34 |
+
sizes: ${scratch.resolution}
|
35 |
+
square: true
|
36 |
+
consistent_transform: True
|
37 |
+
- _target_: training.dataset.transforms.ColorJitter
|
38 |
+
consistent_transform: True
|
39 |
+
brightness: 0.1
|
40 |
+
contrast: 0.03
|
41 |
+
saturation: 0.03
|
42 |
+
hue: null
|
43 |
+
- _target_: training.dataset.transforms.RandomGrayscale
|
44 |
+
p: 0.05
|
45 |
+
consistent_transform: True
|
46 |
+
- _target_: training.dataset.transforms.ColorJitter
|
47 |
+
consistent_transform: False
|
48 |
+
brightness: 0.1
|
49 |
+
contrast: 0.05
|
50 |
+
saturation: 0.05
|
51 |
+
hue: null
|
52 |
+
- _target_: training.dataset.transforms.ToTensorAPI
|
53 |
+
- _target_: training.dataset.transforms.NormalizeAPI
|
54 |
+
mean: [0.485, 0.456, 0.406]
|
55 |
+
std: [0.229, 0.224, 0.225]
|
56 |
+
|
57 |
+
trainer:
|
58 |
+
_target_: training.trainer.Trainer
|
59 |
+
mode: train_only
|
60 |
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
61 |
+
accelerator: cuda
|
62 |
+
seed_value: 123
|
63 |
+
|
64 |
+
model:
|
65 |
+
_target_: training.model.sam2.SAM2Train
|
66 |
+
image_encoder:
|
67 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
68 |
+
scalp: 1
|
69 |
+
trunk:
|
70 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
71 |
+
embed_dim: 112
|
72 |
+
num_heads: 2
|
73 |
+
drop_path_rate: 0.1
|
74 |
+
neck:
|
75 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
76 |
+
position_encoding:
|
77 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
78 |
+
num_pos_feats: 256
|
79 |
+
normalize: true
|
80 |
+
scale: null
|
81 |
+
temperature: 10000
|
82 |
+
d_model: 256
|
83 |
+
backbone_channel_list: [896, 448, 224, 112]
|
84 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
85 |
+
fpn_interp_model: nearest
|
86 |
+
|
87 |
+
memory_attention:
|
88 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
89 |
+
d_model: 256
|
90 |
+
pos_enc_at_input: true
|
91 |
+
layer:
|
92 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
93 |
+
activation: relu
|
94 |
+
dim_feedforward: 2048
|
95 |
+
dropout: 0.1
|
96 |
+
pos_enc_at_attn: false
|
97 |
+
self_attention:
|
98 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
99 |
+
rope_theta: 10000.0
|
100 |
+
feat_sizes: [32, 32]
|
101 |
+
embedding_dim: 256
|
102 |
+
num_heads: 1
|
103 |
+
downsample_rate: 1
|
104 |
+
dropout: 0.1
|
105 |
+
d_model: 256
|
106 |
+
pos_enc_at_cross_attn_keys: true
|
107 |
+
pos_enc_at_cross_attn_queries: false
|
108 |
+
cross_attention:
|
109 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
110 |
+
rope_theta: 10000.0
|
111 |
+
feat_sizes: [32, 32]
|
112 |
+
rope_k_repeat: True
|
113 |
+
embedding_dim: 256
|
114 |
+
num_heads: 1
|
115 |
+
downsample_rate: 1
|
116 |
+
dropout: 0.1
|
117 |
+
kv_in_dim: 64
|
118 |
+
num_layers: 4
|
119 |
+
|
120 |
+
memory_encoder:
|
121 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
122 |
+
out_dim: 64
|
123 |
+
position_encoding:
|
124 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
125 |
+
num_pos_feats: 64
|
126 |
+
normalize: true
|
127 |
+
scale: null
|
128 |
+
temperature: 10000
|
129 |
+
mask_downsampler:
|
130 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
131 |
+
kernel_size: 3
|
132 |
+
stride: 2
|
133 |
+
padding: 1
|
134 |
+
fuser:
|
135 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
136 |
+
layer:
|
137 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
138 |
+
dim: 256
|
139 |
+
kernel_size: 7
|
140 |
+
padding: 3
|
141 |
+
layer_scale_init_value: 1e-6
|
142 |
+
use_dwconv: True # depth-wise convs
|
143 |
+
num_layers: 2
|
144 |
+
|
145 |
+
num_maskmem: 7
|
146 |
+
image_size: ${scratch.resolution}
|
147 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
148 |
+
sigmoid_scale_for_mem_enc: 20.0
|
149 |
+
sigmoid_bias_for_mem_enc: -10.0
|
150 |
+
use_mask_input_as_output_without_sam: true
|
151 |
+
# Memory
|
152 |
+
directly_add_no_mem_embed: true
|
153 |
+
no_obj_embed_spatial: true
|
154 |
+
# use high-resolution feature map in the SAM mask decoder
|
155 |
+
use_high_res_features_in_sam: true
|
156 |
+
# output 3 masks on the first click on initial conditioning frames
|
157 |
+
multimask_output_in_sam: true
|
158 |
+
# SAM heads
|
159 |
+
iou_prediction_use_sigmoid: True
|
160 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
161 |
+
use_obj_ptrs_in_encoder: true
|
162 |
+
add_tpos_enc_to_obj_ptrs: true
|
163 |
+
proj_tpos_enc_in_obj_ptrs: true
|
164 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
165 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
166 |
+
# object occlusion prediction
|
167 |
+
pred_obj_scores: true
|
168 |
+
pred_obj_scores_mlp: true
|
169 |
+
fixed_no_obj_ptr: true
|
170 |
+
# multimask tracking settings
|
171 |
+
multimask_output_for_tracking: true
|
172 |
+
use_multimask_token_for_obj_ptr: true
|
173 |
+
multimask_min_pt_num: 0
|
174 |
+
multimask_max_pt_num: 1
|
175 |
+
use_mlp_for_obj_ptr_proj: true
|
176 |
+
# Compilation flag
|
177 |
+
# compile_image_encoder: False
|
178 |
+
|
179 |
+
####### Training specific params #######
|
180 |
+
# box/point input and corrections
|
181 |
+
prob_to_use_pt_input_for_train: 0.5
|
182 |
+
prob_to_use_pt_input_for_eval: 0.0
|
183 |
+
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
|
184 |
+
prob_to_use_box_input_for_eval: 0.0
|
185 |
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
186 |
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
187 |
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
188 |
+
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
|
189 |
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
190 |
+
# maximum 2 initial conditioning frames
|
191 |
+
num_init_cond_frames_for_train: 2
|
192 |
+
rand_init_cond_frames_for_train: True # random 1~2
|
193 |
+
num_correction_pt_per_frame: 7
|
194 |
+
use_act_ckpt_iterative_pt_sampling: false
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
199 |
+
forward_backbone_per_frame_for_eval: True
|
200 |
+
|
201 |
+
|
202 |
+
data:
|
203 |
+
train:
|
204 |
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
205 |
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
206 |
+
batch_sizes:
|
207 |
+
- ${scratch.train_batch_size}
|
208 |
+
|
209 |
+
datasets:
|
210 |
+
- _target_: training.dataset.utils.RepeatFactorWrapper
|
211 |
+
dataset:
|
212 |
+
_target_: training.dataset.utils.ConcatDataset
|
213 |
+
datasets:
|
214 |
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
215 |
+
transforms: ${vos.train_transforms}
|
216 |
+
training: true
|
217 |
+
video_dataset:
|
218 |
+
_target_: training.dataset.vos_raw_dataset.PNGRawDataset
|
219 |
+
img_folder: ${dataset.img_folder}
|
220 |
+
gt_folder: ${dataset.gt_folder}
|
221 |
+
file_list_txt: ${dataset.file_list_txt}
|
222 |
+
sampler:
|
223 |
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
224 |
+
num_frames: ${scratch.num_frames}
|
225 |
+
max_num_objects: ${scratch.max_num_objects}
|
226 |
+
multiplier: ${dataset.multiplier}
|
227 |
+
shuffle: True
|
228 |
+
num_workers: ${scratch.num_train_workers}
|
229 |
+
pin_memory: True
|
230 |
+
drop_last: True
|
231 |
+
collate_fn:
|
232 |
+
_target_: training.utils.data_utils.collate_fn
|
233 |
+
_partial_: true
|
234 |
+
dict_key: all
|
235 |
+
|
236 |
+
optim:
|
237 |
+
amp:
|
238 |
+
enabled: True
|
239 |
+
amp_dtype: bfloat16
|
240 |
+
|
241 |
+
optimizer:
|
242 |
+
_target_: torch.optim.AdamW
|
243 |
+
|
244 |
+
gradient_clip:
|
245 |
+
_target_: training.optimizer.GradientClipper
|
246 |
+
max_norm: 0.1
|
247 |
+
norm_type: 2
|
248 |
+
|
249 |
+
param_group_modifiers:
|
250 |
+
- _target_: training.optimizer.layer_decay_param_modifier
|
251 |
+
_partial_: True
|
252 |
+
layer_decay_value: 0.9
|
253 |
+
apply_to: 'image_encoder.trunk'
|
254 |
+
overrides:
|
255 |
+
- pattern: '*pos_embed*'
|
256 |
+
value: 1.0
|
257 |
+
|
258 |
+
options:
|
259 |
+
lr:
|
260 |
+
- scheduler:
|
261 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
262 |
+
start_value: ${scratch.base_lr}
|
263 |
+
end_value: ${divide:${scratch.base_lr},10}
|
264 |
+
- scheduler:
|
265 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
266 |
+
start_value: ${scratch.vision_lr}
|
267 |
+
end_value: ${divide:${scratch.vision_lr},10}
|
268 |
+
param_names:
|
269 |
+
- 'image_encoder.*'
|
270 |
+
weight_decay:
|
271 |
+
- scheduler:
|
272 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
273 |
+
value: 0.1
|
274 |
+
- scheduler:
|
275 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
276 |
+
value: 0.0
|
277 |
+
param_names:
|
278 |
+
- '*bias*'
|
279 |
+
module_cls_names: ['torch.nn.LayerNorm']
|
280 |
+
|
281 |
+
loss:
|
282 |
+
all:
|
283 |
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
284 |
+
weight_dict:
|
285 |
+
loss_mask: 20
|
286 |
+
loss_dice: 1
|
287 |
+
loss_iou: 1
|
288 |
+
loss_class: 1
|
289 |
+
supervise_all_iou: true
|
290 |
+
iou_use_l1_loss: true
|
291 |
+
pred_obj_scores: true
|
292 |
+
focal_gamma_obj_score: 0.0
|
293 |
+
focal_alpha_obj_score: -1.0
|
294 |
+
|
295 |
+
distributed:
|
296 |
+
backend: nccl
|
297 |
+
find_unused_parameters: True
|
298 |
+
|
299 |
+
logging:
|
300 |
+
tensorboard_writer:
|
301 |
+
_target_: training.utils.logger.make_tensorboard_logger
|
302 |
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
303 |
+
flush_secs: 120
|
304 |
+
should_log: True
|
305 |
+
log_dir: ${launcher.experiment_log_dir}/logs
|
306 |
+
log_freq: 10
|
307 |
+
|
308 |
+
# initialize from a SAM 2 checkpoint
|
309 |
+
checkpoint:
|
310 |
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
311 |
+
save_freq: 0 # 0 only last checkpoint is saved.
|
312 |
+
model_weight_initializer:
|
313 |
+
_partial_: True
|
314 |
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
315 |
+
strict: True
|
316 |
+
ignore_unexpected_keys: null
|
317 |
+
ignore_missing_keys: null
|
318 |
+
|
319 |
+
state_dict:
|
320 |
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
321 |
+
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
322 |
+
ckpt_state_dict_keys: ['model']
|
323 |
+
|
324 |
+
launcher:
|
325 |
+
num_nodes: 1
|
326 |
+
gpus_per_node: 8
|
327 |
+
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
328 |
+
|
329 |
+
# SLURM args if running on a cluster
|
330 |
+
submitit:
|
331 |
+
partition: null
|
332 |
+
account: null
|
333 |
+
qos: null
|
334 |
+
cpus_per_task: 10
|
335 |
+
use_cluster: false
|
336 |
+
timeout_hour: 24
|
337 |
+
name: null
|
338 |
+
port_range: [10000, 65000]
|
339 |
+
|
eval/grounded_sam/sam2/configs/sam2/sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 112
|
12 |
+
num_heads: 2
|
13 |
+
neck:
|
14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
15 |
+
position_encoding:
|
16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
17 |
+
num_pos_feats: 256
|
18 |
+
normalize: true
|
19 |
+
scale: null
|
20 |
+
temperature: 10000
|
21 |
+
d_model: 256
|
22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
24 |
+
fpn_interp_model: nearest
|
25 |
+
|
26 |
+
memory_attention:
|
27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
28 |
+
d_model: 256
|
29 |
+
pos_enc_at_input: true
|
30 |
+
layer:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
32 |
+
activation: relu
|
33 |
+
dim_feedforward: 2048
|
34 |
+
dropout: 0.1
|
35 |
+
pos_enc_at_attn: false
|
36 |
+
self_attention:
|
37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
38 |
+
rope_theta: 10000.0
|
39 |
+
feat_sizes: [32, 32]
|
40 |
+
embedding_dim: 256
|
41 |
+
num_heads: 1
|
42 |
+
downsample_rate: 1
|
43 |
+
dropout: 0.1
|
44 |
+
d_model: 256
|
45 |
+
pos_enc_at_cross_attn_keys: true
|
46 |
+
pos_enc_at_cross_attn_queries: false
|
47 |
+
cross_attention:
|
48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
49 |
+
rope_theta: 10000.0
|
50 |
+
feat_sizes: [32, 32]
|
51 |
+
rope_k_repeat: True
|
52 |
+
embedding_dim: 256
|
53 |
+
num_heads: 1
|
54 |
+
downsample_rate: 1
|
55 |
+
dropout: 0.1
|
56 |
+
kv_in_dim: 64
|
57 |
+
num_layers: 4
|
58 |
+
|
59 |
+
memory_encoder:
|
60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
61 |
+
out_dim: 64
|
62 |
+
position_encoding:
|
63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
64 |
+
num_pos_feats: 64
|
65 |
+
normalize: true
|
66 |
+
scale: null
|
67 |
+
temperature: 10000
|
68 |
+
mask_downsampler:
|
69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
70 |
+
kernel_size: 3
|
71 |
+
stride: 2
|
72 |
+
padding: 1
|
73 |
+
fuser:
|
74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
75 |
+
layer:
|
76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
77 |
+
dim: 256
|
78 |
+
kernel_size: 7
|
79 |
+
padding: 3
|
80 |
+
layer_scale_init_value: 1e-6
|
81 |
+
use_dwconv: True # depth-wise convs
|
82 |
+
num_layers: 2
|
83 |
+
|
84 |
+
num_maskmem: 7
|
85 |
+
image_size: 1024
|
86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
89 |
+
use_mask_input_as_output_without_sam: true
|
90 |
+
# Memory
|
91 |
+
directly_add_no_mem_embed: true
|
92 |
+
# use high-resolution feature map in the SAM mask decoder
|
93 |
+
use_high_res_features_in_sam: true
|
94 |
+
# output 3 masks on the first click on initial conditioning frames
|
95 |
+
multimask_output_in_sam: true
|
96 |
+
# SAM heads
|
97 |
+
iou_prediction_use_sigmoid: True
|
98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
99 |
+
use_obj_ptrs_in_encoder: true
|
100 |
+
add_tpos_enc_to_obj_ptrs: false
|
101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
102 |
+
# object occlusion prediction
|
103 |
+
pred_obj_scores: true
|
104 |
+
pred_obj_scores_mlp: true
|
105 |
+
fixed_no_obj_ptr: true
|
106 |
+
# multimask tracking settings
|
107 |
+
multimask_output_for_tracking: true
|
108 |
+
use_multimask_token_for_obj_ptr: true
|
109 |
+
multimask_min_pt_num: 0
|
110 |
+
multimask_max_pt_num: 1
|
111 |
+
use_mlp_for_obj_ptr_proj: true
|
112 |
+
# Compilation flag
|
113 |
+
compile_image_encoder: False
|
eval/grounded_sam/sam2/configs/sam2/sam2_hiera_l.yaml
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 144
|
12 |
+
num_heads: 2
|
13 |
+
stages: [2, 6, 36, 4]
|
14 |
+
global_att_blocks: [23, 33, 43]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
window_spec: [8, 4, 16, 8]
|
17 |
+
neck:
|
18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
19 |
+
position_encoding:
|
20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
21 |
+
num_pos_feats: 256
|
22 |
+
normalize: true
|
23 |
+
scale: null
|
24 |
+
temperature: 10000
|
25 |
+
d_model: 256
|
26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
28 |
+
fpn_interp_model: nearest
|
29 |
+
|
30 |
+
memory_attention:
|
31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
32 |
+
d_model: 256
|
33 |
+
pos_enc_at_input: true
|
34 |
+
layer:
|
35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
36 |
+
activation: relu
|
37 |
+
dim_feedforward: 2048
|
38 |
+
dropout: 0.1
|
39 |
+
pos_enc_at_attn: false
|
40 |
+
self_attention:
|
41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
42 |
+
rope_theta: 10000.0
|
43 |
+
feat_sizes: [32, 32]
|
44 |
+
embedding_dim: 256
|
45 |
+
num_heads: 1
|
46 |
+
downsample_rate: 1
|
47 |
+
dropout: 0.1
|
48 |
+
d_model: 256
|
49 |
+
pos_enc_at_cross_attn_keys: true
|
50 |
+
pos_enc_at_cross_attn_queries: false
|
51 |
+
cross_attention:
|
52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
53 |
+
rope_theta: 10000.0
|
54 |
+
feat_sizes: [32, 32]
|
55 |
+
rope_k_repeat: True
|
56 |
+
embedding_dim: 256
|
57 |
+
num_heads: 1
|
58 |
+
downsample_rate: 1
|
59 |
+
dropout: 0.1
|
60 |
+
kv_in_dim: 64
|
61 |
+
num_layers: 4
|
62 |
+
|
63 |
+
memory_encoder:
|
64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
65 |
+
out_dim: 64
|
66 |
+
position_encoding:
|
67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
68 |
+
num_pos_feats: 64
|
69 |
+
normalize: true
|
70 |
+
scale: null
|
71 |
+
temperature: 10000
|
72 |
+
mask_downsampler:
|
73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
74 |
+
kernel_size: 3
|
75 |
+
stride: 2
|
76 |
+
padding: 1
|
77 |
+
fuser:
|
78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
79 |
+
layer:
|
80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
81 |
+
dim: 256
|
82 |
+
kernel_size: 7
|
83 |
+
padding: 3
|
84 |
+
layer_scale_init_value: 1e-6
|
85 |
+
use_dwconv: True # depth-wise convs
|
86 |
+
num_layers: 2
|
87 |
+
|
88 |
+
num_maskmem: 7
|
89 |
+
image_size: 1024
|
90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
93 |
+
use_mask_input_as_output_without_sam: true
|
94 |
+
# Memory
|
95 |
+
directly_add_no_mem_embed: true
|
96 |
+
# use high-resolution feature map in the SAM mask decoder
|
97 |
+
use_high_res_features_in_sam: true
|
98 |
+
# output 3 masks on the first click on initial conditioning frames
|
99 |
+
multimask_output_in_sam: true
|
100 |
+
# SAM heads
|
101 |
+
iou_prediction_use_sigmoid: True
|
102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
103 |
+
use_obj_ptrs_in_encoder: true
|
104 |
+
add_tpos_enc_to_obj_ptrs: false
|
105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
106 |
+
# object occlusion prediction
|
107 |
+
pred_obj_scores: true
|
108 |
+
pred_obj_scores_mlp: true
|
109 |
+
fixed_no_obj_ptr: true
|
110 |
+
# multimask tracking settings
|
111 |
+
multimask_output_for_tracking: true
|
112 |
+
use_multimask_token_for_obj_ptr: true
|
113 |
+
multimask_min_pt_num: 0
|
114 |
+
multimask_max_pt_num: 1
|
115 |
+
use_mlp_for_obj_ptr_proj: true
|
116 |
+
# Compilation flag
|
117 |
+
compile_image_encoder: False
|