Commit a9d25c70
Parent(s):
gradio demo for ZeroGPU, HF

- .gitattributes +35 -0
- .gitignore +171 -0
- README.md +13 -0
- app.py +121 -0
- configs/COCO-81.yaml +3 -0
- configs/Cityscapes.yaml +3 -0
- configs/DRAM.yaml +3 -0
- configs/VOC2012.yaml +2 -0
- examples/COCO-81_eg.jpg +0 -0
- examples/Cityscapes_eg.jpg +0 -0
- examples/DRAM_eg.jpg +0 -0
- examples/VOC2012_eg.jpg +0 -0
- gradio_cached_examples/16/log.csv +5 -0
- gradio_cached_examples/16/output/6b3896574851c5665d17/image.webp +0 -0
- gradio_cached_examples/16/output/8d66e32b3b15feb7ecc9/image.webp +0 -0
- gradio_cached_examples/16/output/ca6408b417ed4de51f74/image.webp +0 -0
- gradio_cached_examples/16/output/e9085a590715dd9a4cbc/image.webp +0 -0
- pre-requirements.txt +19 -0
- pretrained-models/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py +43 -0
- pretrained-models/checkpoints/groundingdino_swint_ogc.pth +3 -0
- pretrained-models/checkpoints/put pre-trained checkpoints here.txt +0 -0
- pretrained-models/checkpoints/ram_plus_swin_large_14m.pth +3 -0
- pretrained-models/checkpoints/sam_hq_vit_l.pth +3 -0
- requirements.txt +3 -0
- utils/Arial.ttf +0 -0
- utils/blip2_utils.py +40 -0
- utils/env_utils.py +56 -0
- utils/grounded_sam_utils.py +348 -0
- utils/labels_utils.py +214 -0
- utils/llms_utils.py +233 -0
- utils/ram_utils.py +86 -0
- utils/timer_utils.py +89 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,171 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+
+*.pth
+*.bin
+*.log
+*.safetensors
+outputs/
+outputs_single/
+results/
+pretrained-models/checkpoints/
+
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Training-Free Zero-Shot Semantic Segmentation With LLM Refinement
+emoji: ⚡
+colorFrom: blue
+colorTo: purple
+sdk: gradio
+sdk_version: 4.38.1
+app_file: app.py
+pinned: false
+license: agpl-3.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,121 @@
+import os
+import cv2
+import spaces
+import gradio as gr
+from PIL import Image
+from omegaconf import OmegaConf
+
+# set up environment
+from utils.env_utils import set_random_seed, use_lower_vram
+from utils.timer_utils import Timer
+
+set_random_seed(1024)
+timer = Timer()
+timer.start()
+# use_lower_vram()
+
+# import functions
+from utils.labels_utils import Labels
+from utils.ram_utils import ram_inference
+from utils.blip2_utils import blip2_caption
+from utils.llms_utils import pre_refinement, make_prompt, init_model
+from utils.grounded_sam_utils import run_grounded_sam
+
+
+# hardcode parameters for G-SAM
+box_threshold = 0.18
+text_threshold = 0.15
+iou_threshold = 0.8
+
+global current_config, L, llm, system_prompt
+
+# load Llama-3 here to avoid loading it during the inference.
+llm = init_model("Meta-Llama-3-8B-Instruct")
+current_config = ""
+L = None
+system_prompt = None
+
+def load_config(config_type):
+    config = OmegaConf.load(os.path.join(os.path.dirname(__file__), f"configs/{config_type}.yaml"))
+    L = Labels(config=config)
+    # init labels and llm prompt, only Meta-Llama-3-8B-Instruct is supported for the online demo, but you can use any model in your local environment using our released code
+    system_prompt = make_prompt(", ".join(L.LABELS))
+    return L, system_prompt
+
+@spaces.GPU(duration=120)
+def process(image_ori, config_type):
+    global current_config, L, llm, system_prompt
+    if current_config != config_type:
+        L, system_prompt = load_config(config_type)
+        current_config = config_type
+    else:
+        pass
+    image_ori = cv2.cvtColor(image_ori, cv2.COLOR_BGR2RGB)
+    image_pil = Image.fromarray(image_ori)
+    labels_ram = ram_inference(image_pil) + ": " + blip2_caption(image_pil)
+    converted_labels, llm_output = pre_refinement([labels_ram], system_prompt, llm=llm)
+    labels_llm = L.check_labels(converted_labels)[0]
+    print("labels_ram: ", labels_ram)
+    print("llm_output: ", llm_output)
+    print("labels_llm: ", labels_llm)
+
+    # run sam
+    label_res, bboxes, output_labels, output_prob_maps, output_points = run_grounded_sam(
+        input_image = {"image": image_pil, "mask": None},
+        text_prompt = labels_llm,
+        box_threshold = box_threshold,
+        text_threshold = text_threshold,
+        iou_threshold = iou_threshold,
+        LABELS = L.LABELS,
+        IDS = L.IDS,
+        llm = llm,
+        timer = timer,
+    )
+
+    # draw mask and save image
+    ours = L.draw_mask(label_res, image_ori, print_label=True, tag="Ours")
+    return cv2.cvtColor(ours, cv2.COLOR_BGR2RGB)
+
+
+if __name__ == "__main__":
+    # options for different settings
+    dropdown_options = ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]
+    default_option = "COCO-81"
+
+    with gr.Blocks() as demo:
+        gr.HTML(
+            """
+            <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
+            Training-Free Zero-Shot Semantic Segmentation with LLM Refinement
+            </h1>
+            <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
+            <a style="text-align: center; display:inline-block"
+                href="https://sky24h.github.io/websites/bmvc2024_training-free-semseg-with-LLM/">
+                <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
+                alt="Paper Page">
+            </a>
+            <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Training-Free_Zero-Shot_Semantic_Segmentation_with_LLM_Refinement?duplicate=true">
+                <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
+            </a>
+            </p>
+            """
+        )
+        gr.Interface(
+            fn=process,
+            inputs=[gr.Image(type="numpy", height="384"), gr.Dropdown(choices=dropdown_options, label="Refinement Type", value=default_option)],
+            outputs="image",
+            description="""<html>
+            <p style="text-align:center;"> This is an online demo for the paper "Training-Free Zero-Shot Semantic Segmentation with LLM Refinement" (BMVC 2024). </p>
+            <p style="text-align:center;"> Usage: Please select or upload an image and choose a dataset setting for semantic segmentation refinement.</p>
+            </html>""",
+            allow_flagging='never',
+            examples=[
+                ["examples/Cityscapes_eg.jpg", "Cityscapes"],
+                ["examples/DRAM_eg.jpg", "DRAM"],
+                ["examples/COCO-81_eg.jpg", "COCO-81"],
+                ["examples/VOC2012_eg.jpg", "VOC2012"],
+            ],
+            cache_examples=True,
+        )
+
+    demo.queue(max_size=10).launch()
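The demo above wires `process` into a Gradio UI, but the same entry point can be exercised headlessly. A minimal sketch, under two assumptions not stated in the commit: the checkpoints listed in this commit are present locally with a CUDA device, and the `spaces.GPU` decorator degrades to a no-op outside a ZeroGPU Space.

import cv2
from app import process  # importing app loads the models but does not launch the UI (guarded by __main__)

# process() expects a BGR numpy array (it converts to RGB internally) plus a dataset name
image_bgr = cv2.imread("examples/VOC2012_eg.jpg")
overlay_rgb = process(image_bgr, "VOC2012")        # RGB overlay of the refined segmentation
cv2.imwrite("result.png", cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))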
configs/COCO-81.yaml
ADDED
@@ -0,0 +1,3 @@
+Name: COCO-81
+label_list: "unlabeled, person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic-light, fire-hydrant, stop-sign, parking-meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports-ball, kite, baseball-bat, baseball-glove, skateboard, surfboard, tennis-racket, bottle, wine-glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot-dog, pizza, donut, cake, chair, couch, potted-plant, bed, dining-table, toilet, tv, laptop, mouse, remote, keyboard, cell-phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy-bear, hair-drier, toothbrush"
+mask_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
configs/Cityscapes.yaml
ADDED
@@ -0,0 +1,3 @@
+Name: Cityscapes
+label_list: "background, road, sidewalk, building, wall, fence, pole, traffic-light, traffic-sign, tree, terrain, sky, person, rider, car, truck, bus, train, motorcycle, bicycle"
+mask_ids: [0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
configs/DRAM.yaml
ADDED
@@ -0,0 +1,3 @@
+Name: DRAM
+label_list: "background, bird, boat, bottle, cat, chair, cow, dog, horse, person, potted-plant, sheep"
+mask_ids: [0, 3, 4, 5, 8, 9, 10, 12, 13, 15, 16, 17]
configs/VOC2012.yaml
ADDED
@@ -0,0 +1,2 @@
+Name: VOC2012
+label_list: "background, aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, dining-table, dog, horse, motorbike, person, potted-plant, sheep, sofa, train, monitor, void"
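Each config pairs a comma-separated `label_list` with dataset-specific `mask_ids`; VOC2012 omits `mask_ids`, in which case `utils/labels_utils.py` falls back to consecutive indices. A quick sanity-check sketch, assuming it is run from the repository root:

from omegaconf import OmegaConf

for name in ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]:
    cfg = OmegaConf.load(f"configs/{name}.yaml")
    labels = cfg.label_list.split(", ")
    # mask_ids is optional; Labels() uses 0..N-1 when it is absent
    ids = list(cfg.mask_ids) if "mask_ids" in cfg else list(range(len(labels)))
    assert len(labels) == len(ids), f"{name}: {len(labels)} labels vs {len(ids)} mask ids"
    print(name, len(labels), "classes")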
examples/COCO-81_eg.jpg
ADDED
examples/Cityscapes_eg.jpg
ADDED
examples/DRAM_eg.jpg
ADDED
examples/VOC2012_eg.jpg
ADDED
gradio_cached_examples/16/log.csv
ADDED
@@ -0,0 +1,5 @@
+output,flag,username,timestamp
+"{""path"": ""gradio_cached_examples/16/output/e9085a590715dd9a4cbc/image.webp"", ""url"": ""/file=/tmp/gradio/f17a9230acfa1f7c9d09b85c0c0528e64c5a19ec/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:26.975686
+"{""path"": ""gradio_cached_examples/16/output/ca6408b417ed4de51f74/image.webp"", ""url"": ""/file=/tmp/gradio/28f694172a8e086c7d12474e78b1e36453357589/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:28.056743
+"{""path"": ""gradio_cached_examples/16/output/6b3896574851c5665d17/image.webp"", ""url"": ""/file=/tmp/gradio/f2a070d0cd932cc4bf9ddb7f4cca22c01d4d4e37/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:30.868722
+"{""path"": ""gradio_cached_examples/16/output/8d66e32b3b15feb7ecc9/image.webp"", ""url"": ""/file=/tmp/gradio/9318730ca69938781665675ccbe76d635bc47a2d/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:32.230374
gradio_cached_examples/16/output/6b3896574851c5665d17/image.webp
ADDED
gradio_cached_examples/16/output/8d66e32b3b15feb7ecc9/image.webp
ADDED
gradio_cached_examples/16/output/ca6408b417ed4de51f74/image.webp
ADDED
gradio_cached_examples/16/output/e9085a590715dd9a4cbc/image.webp
ADDED
pre-requirements.txt
ADDED
@@ -0,0 +1,19 @@
+# Install transformers and timm in second stage to avoid error
+torch==2.3.1 #pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
+torchvision==0.18.1
+setuptools==69.5.1
+gradio==4.38.1
+openai>=1.0.0
+opencv_python==4.8.1.78
+diffusers[torch]==0.29.2
+termcolor
+fairscale
+natsort
+omegaconf
+pycocotools
+matplotlib
+onnxruntime
+onnx
+groundingdino-py
+segment_anything@git+https://github.com/SysCV/sam-hq.git
+ram@git+https://github.com/xinyu1205/recognize-anything.git
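The split between `pre-requirements.txt` and `requirements.txt` exists because `transformers`/`timm` need to be installed after GroundingDINO (see the comment in `requirements.txt` below). A minimal sketch of the resulting two-stage install, assuming pip can reach the pinned wheels:

import subprocess
import sys

# install in two stages so transformers/timm land after groundingdino-py
for req_file in ("pre-requirements.txt", "requirements.txt"):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_file])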
pretrained-models/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
ADDED
@@ -0,0 +1,43 @@
+batch_size = 1
+modelname = "groundingdino"
+backbone = "swin_T_224_1k"
+position_embedding = "sine"
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+two_stage_type = "standard"
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+transformer_activation = "relu"
+dec_pred_bbox_embed_share = True
+dn_box_noise_scale = 1.0
+dn_label_noise_ratio = 0.5
+dn_label_coef = 1.0
+dn_bbox_coef = 1.0
+embed_init_tgt = True
+dn_labelbook_size = 2000
+max_text_len = 256
+text_encoder_type = "bert-base-uncased"
+use_text_enhancer = True
+use_fusion_layer = True
+use_checkpoint = True
+use_transformer_ckpt = True
+use_text_cross_attention = True
+text_dropout = 0.0
+fusion_dropout = 0.0
+fusion_droppath = 0.1
+sub_sentence_present = True
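This config is a plain Python module of top-level variables; `_load_model` in `utils/grounded_sam_utils.py` reads it through `SLConfig.fromfile` and then accesses the values as attributes. A small inspection sketch (the printed values are taken from the file above; attribute access mirrors how `_load_model` uses the object):

from groundingdino.util.slconfig import SLConfig

cfg = SLConfig.fromfile(
    "pretrained-models/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
print(cfg.backbone, cfg.hidden_dim, cfg.max_text_len)  # swin_T_224_1k 256 256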
pretrained-models/checkpoints/groundingdino_swint_ogc.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799
+size 693997677
pretrained-models/checkpoints/put pre-trained checkpoints here.txt
ADDED
File without changes
pretrained-models/checkpoints/ram_plus_swin_large_14m.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:497c178836ba66698ca226c7895317e6e800034be986452dbd2593298d50e87d
+size 3010210801
pretrained-models/checkpoints/sam_hq_vit_l.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1a6c385d62bf005ded91a54d5ec55c985cfc4103ef89c08d90f39f04934c343
+size 1254865805
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+# Somehow, we needed to put transformers after GroundingDINO.
+transformers==4.42.4
+timm==1.0.8
utils/Arial.ttf
ADDED
Binary file (276 kB)
utils/blip2_utils.py
ADDED
@@ -0,0 +1,40 @@
+import os
+import torch
+from transformers import Blip2Processor, Blip2ForConditionalGeneration  # , BitsAndBytesConfig
+from .env_utils import get_device, low_vram_mode
+
+device = get_device()
+
+blip2_model_id = "Salesforce/blip2-opt-2.7b"  # or replace with your local model path
+blip2_precision = torch.bfloat16
+
+# Load BLIP2 model and processor from HuggingFace
+blip2_processor = Blip2Processor.from_pretrained(blip2_model_id)
+if low_vram_mode:
+    blip2_model = Blip2ForConditionalGeneration.from_pretrained(
+        blip2_model_id,
+        torch_dtype=blip2_precision,
+        device_map=device,
+        # quantization_config = BitsAndBytesConfig(load_in_8bit=True) if low_vram_mode else None,  # ZeroGPU does not support quantization.
+    ).eval()
+else:
+    blip2_model = Blip2ForConditionalGeneration.from_pretrained(blip2_model_id, torch_dtype=blip2_precision, device_map=device).eval()
+
+
+def blip2_caption(raw_image):
+    # unconditional image captioning
+    inputs = blip2_processor(raw_image, return_tensors="pt")
+    inputs = inputs.to(device=device, dtype=blip2_precision)
+    out = blip2_model.generate(**inputs)
+    caption = blip2_processor.decode(out[0], skip_special_tokens=True)
+    return caption
+
+
+# if __name__ == "__main__":
+#     from PIL import Image
+
+#     # Test the BLIP-2 captioning model
+#     image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
+#     image = Image.open(image_path)
+#     result = blip2_caption(image)
+#     print(result)
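A minimal usage sketch, mirroring the commented-out test block above (it assumes a CUDA device and that the Salesforce/blip2-opt-2.7b weights can be downloaded on first use):

from PIL import Image
from utils.blip2_utils import blip2_caption

image = Image.open("examples/VOC2012_eg.jpg").convert("RGB")
print(blip2_caption(image))  # unconditional caption, used as extra context for the LLM refinement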
utils/env_utils.py
ADDED
@@ -0,0 +1,56 @@
+# Avoid multiple imports of the same module. Use this to import the module only once.
+# Also, ensure that the device and pretrained models folder are consistent across the project.
+
+import os
+import torch
+
+global low_vram_mode
+low_vram_mode = False
+
+
+def use_lower_vram():
+    global low_vram_mode
+    low_vram_mode = True
+
+
+def get_device():
+    device = torch.device("cuda")  # must use GPU in online demo version
+    return device
+
+
+def set_random_seed(seed: int):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def get_pretrained_models_folder():
+    return os.path.join(os.path.dirname(__file__), "../pretrained-models")
+
+
+# def download_pretrained_models():
+#     pretrained_models_folder = get_pretrained_models_folder()
+#     # hard-coded download links
+#     groundingdino_link = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
+#     sam_link = "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
+#     ram_link = "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
+#     groundingdino_ckpt = os.path.join(pretrained_models_folder, "checkpoints/groundingdino_swint_ogc.pth")
+#     sam_ckpt = os.path.join(pretrained_models_folder, "checkpoints/sam_hq_vit_l.pth")
+#     ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
+
+#     # download pretrained models if not exists
+#     if not os.path.exists(groundingdino_ckpt):
+#         print(f"Downloading pretrained model: {groundingdino_ckpt}")
+#         os.system(f"wget -O {groundingdino_ckpt} {groundingdino_link} -q")
+#     if not os.path.exists(sam_ckpt):
+#         print(f"Downloading pretrained model: {sam_ckpt}")
+#         os.system(f"wget -O {sam_ckpt} {sam_link} -q")
+#     if not os.path.exists(ram_ckpt):
+#         print(f"Downloading pretrained model: {ram_ckpt}")
+#         os.system(f"wget -O {ram_ckpt} {ram_link} -q")
+
+
+# # download pretrained models when imported
+# download_pretrained_models()
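The module-level `low_vram_mode` flag is read at import time by the other utilities, so `use_lower_vram()` only takes effect when it runs before they are imported, which is exactly how `app.py` orders its imports. A short sketch of that intended call order (a sketch, not part of the commit):

from utils.env_utils import set_random_seed, get_device, use_lower_vram

set_random_seed(1024)
# use_lower_vram()              # optional; must run before importing blip2_utils / llms_utils
device = get_device()           # always CUDA in this online-demo version

from utils.blip2_utils import blip2_caption  # picks up low_vram_mode as set above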
utils/grounded_sam_utils.py
ADDED
@@ -0,0 +1,348 @@
+import os
+import cv2
+import torch
+import torchvision
+import numpy as np
+from PIL import Image, ImageFont
+import traceback
+
+# environment variables and paths
+from .env_utils import get_device, get_pretrained_models_folder, low_vram_mode
+
+device = get_device()
+pretrained_models_folder = get_pretrained_models_folder()
+groundingdino_ckpt = os.path.join(pretrained_models_folder, "checkpoints/groundingdino_swint_ogc.pth")
+sam_ckpt = os.path.join(pretrained_models_folder, "checkpoints/sam_hq_vit_l.pth")
+
+# segment anything
+from segment_anything import build_sam_vit_l, SamPredictor
+
+# Grounding DINO
+from groundingdino.models import build_model
+from groundingdino.util.slconfig import SLConfig
+from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+import groundingdino.datasets.transforms as T
+
+font_family = os.path.join(os.path.dirname(__file__), "Arial.ttf")
+font_size = 24
+font = ImageFont.truetype(font_family, font_size)
+
+from .llms_utils import post_refinement
+
+
+def draw_bboxes(ours_bboxes, output_labels, bboxes, output_points, output_prob_maps):
+    # draw bboxes on the image
+    for label, bbox in zip(output_labels, bboxes):
+        bbox = bbox.cpu().numpy()
+        bbox = [int(round(bbox[0])), int(round(bbox[1])), int(round(bbox[2])), int(round(bbox[3]))]
+        # print("label, bbox", label, bbox)
+        cv2.rectangle(ours_bboxes, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
+        # caption inside the bbox, below the top left corner 20 pixels
+        cv2.putText(ours_bboxes, label, (bbox[0], bbox[1] + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+    try:
+        for points in output_points:
+            for point in points:
+                # draw a cross on the point
+                cv2.drawMarker(ours_bboxes, (int(point[0]), int(point[1])), (0, 0, 255), cv2.MARKER_CROSS, 10, 2)
+    except:  # noqa
+        pass
+
+    # Draw the probability maps
+    # if output_prob_maps is not None:
+    #     output_prob_maps = np.concatenate(output_prob_maps, axis=1)
+    #     ours_bboxes = np.concatenate([output_prob_maps, ours_bboxes], axis=1)
+    return ours_bboxes
+
+
+def transform_image(image_pil):
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image, _ = transform(image_pil, None)  # 3, h, w
+    return image
+
+
+def _load_model(model_config_path, model_checkpoint_path, device):
+    args = SLConfig.fromfile(model_config_path)
+    args.device = device
+    model = build_model(args)
+    model.load_state_dict(clean_state_dict(torch.load(model_checkpoint_path, map_location="cpu")["model"]), strict=False)
+    return model.to(device=device).eval()
+
+
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
+    caption = caption.lower()
+    caption = caption.strip()
+    if not caption.endswith("."):
+        caption = caption + "."
+
+    with torch.no_grad():
+        outputs = model(image[None], captions=[caption])
+    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+    logits.shape[0]
+
+    # filter output
+    logits_filt = logits.clone()
+    boxes_filt = boxes.clone()
+    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+    logits_filt = logits_filt[filt_mask]  # num_filt, 256
+    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+    logits_filt.shape[0]
+
+    # get phrase
+    tokenlizer = model.tokenizer
+    tokenized = tokenlizer(caption)
+    # build pred
+    pred_phrases = []
+    scores = []
+    for logit, box in zip(logits_filt, boxes_filt):
+        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
+        if with_logits:
+            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+        else:
+            pred_phrases.append(pred_phrase)
+        scores.append(logit.max().item())
+    return boxes_filt, torch.Tensor(scores), pred_phrases
+
+
+def postprocess_masks(input_masks, input_pred_phrases):
+    input_masks_ = input_masks.cpu().numpy().transpose(0, 2, 3, 1).copy()
+    output_masks = input_masks.cpu().numpy().transpose(0, 2, 3, 1).copy()
+    for i in range(len(output_masks)):
+        for j in range(len(output_masks)):
+            if i == j:
+                continue
+            if ((input_masks_[i] * input_masks_[j]).sum() > 0) and (input_pred_phrases[i].split("(")[0] != input_pred_phrases[j].split("(")[0]):
+                # if two masks overlap and have different labels
+                if float(input_pred_phrases[i].split("(")[1].split(")")[0]) < float(input_pred_phrases[j].split("(")[1].split(")")[0]):
+                    # if the score of the first mask is lower than the second mask, remove overlapping area from the first mask
+                    output_masks[i] = np.logical_and(output_masks[i], np.logical_not(input_masks_[j]))
+                else:
+                    # otherwise, remove overlapping area from the second mask
+                    output_masks[j] = np.logical_and(output_masks[j], np.logical_not(input_masks_[i]))
+    return output_masks.transpose(3, 0, 1, 2)[0]
+
+
+groundingdino_model = None
+sam_predictor = None
+already_converted = {}
+config_file = os.path.join(pretrained_models_folder, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
+
+
+def _find_higest_points(logits_map, num_top_points=20):
+    if num_top_points == 0:
+        return logits_map, []
+    # find the highest points on the logits map
+    gray = cv2.cvtColor(logits_map, cv2.COLOR_BGR2GRAY).astype("uint8")
+    # find the highest points
+    points = []
+    for i in range(num_top_points):
+        y, x = np.unravel_index(np.argmax(gray, axis=None), gray.shape)
+        points.append((x, y))
+        gray[y, x] = 0
+    # draw points
+    for point in points:
+        cv2.drawMarker(logits_map, point, (0, 0, 255), cv2.MARKER_CROSS, 10, 3)
+    return logits_map, points
+
+
+def _find_contour_points(logits_map, num_points=5):
+    if num_points == 0:
+        return logits_map, []
+    # find contours and get number of points on the contour, then draw the points on the image
+    gray = cv2.cvtColor(logits_map, cv2.COLOR_BGR2GRAY).astype("uint8")
+    ret, thresh = cv2.threshold(gray, 155, 255, 0)
+    # erode to make the contour thinner
+    kernel = np.ones((13, 13), np.uint8)
+    # only apply erode when the image is large enough, otherwise, skip it
+    if np.sum(thresh) > (gray.shape[0] * gray.shape[1] * 255 * 0.1):
+        erode_iterations = int(np.log2(min(gray.shape[0], gray.shape[1])) - 1)
+        thresh = cv2.erode(thresh, kernel, iterations=erode_iterations)
+
+    # only use the largest contour
+    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+    largest_contour = max(contours, key=cv2.contourArea)
+
+    points = []
+    if len(largest_contour) > num_points:
+        for i in range(0, len(largest_contour), len(largest_contour) // num_points):
+            if len(points) == num_points:
+                break
+            x, y = largest_contour[i][0]
+            points.append((x, y))
+
+    # make sure the points are at the same number as num_points
+    if len(points) == 0:
+        raise ValueError("no points found")
+    elif len(points) < num_points:
+        for i in range(num_points - len(points)):
+            points.append(points[-1])
+    elif len(points) > num_points:
+        points = points[:num_points]
+    else:
+        pass
+    # draw points
+    for point in points:
+        # cv2.circle(logits_map, point, 3, (0, 0, 255), -1)
+        cv2.drawMarker(logits_map, point, (0, 0, 255), cv2.MARKER_CROSS, 10, 3)
+
+    return logits_map, points
+
+
+def _process_logits(logits, pred_phrases, top_n_points):
+    # print("logits", logits.shape)
+    # torch.Size([3, 1, 468, 500])
+    logits = logits.cpu().numpy()[:, 0, :, :]
+    logits = ((logits - np.min(logits)) / (np.max(logits) - np.min(logits))) * 255
+    logits_maps = []
+    points_list = []
+    for i, logits_map in enumerate(logits):
+        try:
+            logits_map = cv2.cvtColor(np.array(logits_map, dtype=np.uint8), cv2.COLOR_GRAY2BGR)
+            logits_map, points = _find_higest_points(logits_map, num_top_points=top_n_points)
+            if len(points) == 0:
+                points = None
+            cv2.putText(logits_map, pred_phrases[i], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            logits_maps.append(logits_map)
+            points_list.append(points)
+        except Exception as e:
+            print("error in _process_logits", e)
+            continue
+    return logits_maps, points_list
+
+
+def run_grounded_sam(
+    input_image,
+    text_prompt,
+    box_threshold,
+    text_threshold,
+    iou_threshold,
+    LABELS = [],
+    IDS = [],
+    llm = None,
+    timer = None,
+    # for ablation study
+    wo_post = False,
+    top_n_points = 20,
+):
+    global groundingdino_model, sam_predictor, already_converted
+
+    # load image
+    image_pil = input_image["image"].convert("RGB")
+    transformed_image = transform_image(image_pil).to(device=device)
+    size = image_pil.size
+
+    if groundingdino_model is None:
+        groundingdino_model = _load_model(config_file, groundingdino_ckpt, device=device)
+
+    # run grounding dino model
+    boxes_filt, scores, pred_phrases = get_grounding_output(groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold)
+    timer.check("get_grounding_output")
+
+    # process boxes
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+        boxes_filt[i][2:] += boxes_filt[i][:2]
+    boxes_filt = boxes_filt.cpu()
+
+    # nms
+    nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
+    boxes_filt = boxes_filt[nms_idx]
+    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+
+    if sam_predictor is None:
+        # initialize SAM
+        assert sam_ckpt, "sam_ckpt is not found!"
+        sam = build_sam_vit_l(checkpoint=sam_ckpt)
+        sam.to(device=device).eval()
+        sam_predictor = SamPredictor(sam)
+    sam_predictor.model.to(device=device)
+    image = np.array(image_pil)
+    sam_predictor.set_image(image)
+
+    input_box = torch.tensor(boxes_filt, device=device)
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(input_box, image.shape[:2])
+    logits, _, _ = sam_predictor.predict_torch(
+        point_coords = None,
+        point_labels = None,
+        boxes = transformed_boxes,
+        multimask_output = False,
+        return_logits = True,
+        hq_token_only = False,
+    )
+    timer.check("get prob")
+
+    output_prob_maps, output_points = _process_logits(logits, pred_phrases, top_n_points=top_n_points)
+    if top_n_points == 0:
+        # processing without points prompt, for ablation study
+        print("processing without points prompt, for ablation study")
+        point_coords = None
+        point_labels = None
+    else:
+        if None in output_points:
+            point_coords = None
+            point_labels = None
+        else:
+            point_coords = torch.tensor(np.array(output_points), device=device)
+            point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
+            point_labels = torch.ones(point_coords.shape[:2], device=device)
+            # print("point_coords", point_coords.shape, point_labels.shape, transformed_boxes.shape)
+            transformed_boxes = transformed_boxes[: point_coords.shape[0]]
+
+    masks, _, _ = sam_predictor.predict_torch(
+        point_coords = point_coords,
+        point_labels = point_labels,
+        boxes = transformed_boxes,
+        multimask_output = False,
+        hq_token_only = False,
+    )
+    masks = postprocess_masks(masks, pred_phrases)
+    timer.check("postprocess_masks")
+
+    label_image = Image.new("L", size, color=0)
+    label_draw = np.array(label_image)
+    output_labels = []
+    for mask, pred_phrase in zip(masks, pred_phrases):
+        try:
+            label = pred_phrase.split("(")[0]
+            if label in ["", " "]:
+                # skip empty label
+                continue
+            elif label in LABELS:
+                # no need to convert if it's one of the target labels
+                post_label = label
+            elif label in already_converted:
+                # check if the label was converted before to save time and model calls
+                post_label = already_converted[label]
+                print("already converted: {} to {}".format(label, already_converted[label]))
+            else:
+                # convert the label using llm model
+                label = label.replace(" ", "") if "-" in label else label
+                if wo_post:
+                    print("wo_post is True, for ablation study")
+                    # skip post refinement, for ablation study
+                    post_label = label
+                else:
+                    post_label = post_refinement(LABELS, label, llm=llm)
+                    print("convert from {} to {}".format(label, post_label))
+                # add to the already_converted list, no matter it's in the list or not to save $!
+                already_converted.update({label: post_label})
+            if post_label not in LABELS:
+                raise ValueError("label not found, {} from {}".format(post_label, label))
+            output_labels.append(post_label)
+            label_index = LABELS.index(post_label)
+            label_draw[mask] = IDS[label_index]
+        except ValueError as e:
+            print("e", e)
+            print("label not found: ", pred_phrase)
+            traceback.print_exc()
+            continue
+    timer.check("llm+draw label")
+    return label_draw, boxes_filt, output_labels, output_prob_maps, output_points
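`postprocess_masks` resolves overlaps between differently labelled masks by carving the overlap out of the lower-scored one. A toy check of that rule (a sketch; it only exercises the numpy logic, although importing the module still pulls in the GroundingDINO/SAM dependencies listed above):

import torch
from utils.grounded_sam_utils import postprocess_masks

masks = torch.zeros(2, 1, 4, 4, dtype=torch.bool)
masks[0, 0, :, :2] = True                      # "cat" covers the left two columns
masks[1, 0, :, 1:] = True                      # "dog" covers the right three columns, overlapping column 1
out = postprocess_masks(masks, ["cat(0.30)", "dog(0.90)"])
print(out.shape)                               # (2, 4, 4)
print(bool(out[0, :, 1].any()))                # False: the overlap was removed from the lower-scored "cat" mask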
utils/labels_utils.py
ADDED
@@ -0,0 +1,214 @@
+import cv2
+import numpy as np
+
+COCO_CATEGORIES = [
+    # borrowed from detectron2
+    # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_stuff_10k.py
+    {"color": [0, 0, 0], "isthing": 0, "id": 0, "name": "unlabeled"},
+    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+]
+
+
+def create_coco_colormap(IDs):
+    all_colors = []
+    vis_colors = [category["color"] for category in COCO_CATEGORIES]
+    used_ids = [category["id"] for category in COCO_CATEGORIES]
+    all_colors = [vis_colors[used_ids.index(id)] if id in used_ids else [0, 0, 0] for id in range(max(IDs)+1)]
+    return np.array(all_colors, dtype=int)
+
+
+def create_cityscapes_colormap(IDs):
+    vis_colors = [
+        (0, 0, 0),
+        (128, 64, 128),
+        (244, 35, 232),
+        (70, 70, 70),
+        (102, 102, 156),
+        (190, 153, 153),
+        (153, 153, 153),
+        (250, 170, 30),
+        (220, 220, 0),
+        (107, 142, 35),
+        (152, 251, 152),
+        (70, 130, 180),
+        (220, 20, 60),
+        (255, 0, 0),
+        (0, 0, 142),
+        (0, 0, 70),
+        (0, 60, 100),
+        (0, 80, 100),
+        (0, 0, 230),
+        (119, 11, 32),
+    ]
+
+    all_colors = [vis_colors[IDs.index(id)] if id in IDs else [0, 0, 0] for id in range(max(IDs)+1)]
+    return np.array(all_colors, dtype=int)
+
+def create_pascal_label_colormap(n_labels=256):
+    def bitget(byteval, idx):
+        return ((byteval & (1 << idx)) != 0)
+
+    cmap = np.zeros((n_labels, 3), dtype=np.uint8)
+    for i in range(n_labels):
+        r = g = b = 0
+        c = i
+        for j in range(8):
+            r = r | (bitget(c, 0) << 7-j)
+            g = g | (bitget(c, 1) << 7-j)
+            b = b | (bitget(c, 2) << 7-j)
+            c = c >> 3
+        cmap[i] = np.array([r, g, b])
+    return cmap
+
+
+
+class Labels:
+    def __init__(self, config=None):
+        max_label_num = 200
+        if config is not None:
+            self.LABELS = config.label_list.split(", ")
+            self.IDS = config.mask_ids if hasattr(config, "mask_ids") else [i for i in range(len(self.LABELS))]
+            print("self.IDS", self.IDS)
+            if len(self.LABELS) > max_label_num:
+                raise ValueError(f"Too many labels! The maximum number of labels is {max_label_num}.")
+        else:
+            raise NotImplementedError("config is None")
+
+        if "COCO" in config.Name:
+            self.COLORS = create_coco_colormap(self.IDS)
+        elif "City" in config.Name:
+            self.COLORS = create_cityscapes_colormap(self.IDS)
+        else:
+            # default to pascal label colormap
+            self.COLORS = create_pascal_label_colormap()
+
+        assert len(self.COLORS) >= len(self.LABELS), f"len(self.COLORS)={len(self.COLORS)} < len(self.LABELS)={len(self.LABELS)}"
+
+    def check_labels(self, labels_list):
+        output_labels_list = []
+        for labels in labels_list:
+            output_labels = []
+            labels = labels.split(", ")
+            for label in labels:
+                if label == "background":
+                    # skip the background label
+                    continue
+                if label in self.LABELS:
+                    output_labels.append(label)
+            output_labels = list(set(output_labels))
+            output_labels_list.append(", ".join(output_labels))
+        return output_labels_list
+
+    def draw_mask(self, label_ori, image_ori, print_label=False, tag="", only_label=False):
+        label_ori = label_ori.astype(np.uint8)
+        label = np.zeros_like(image_ori, dtype=np.uint8)
+        # print("{}: {}".format(tag, np.unique(label_ori)))
+        for id in np.unique(label_ori):
+            # print("id", id)
+            if id == 0 or id == 255:
+                continue
+            elif id not in self.IDS:
+                print(f"Label {id} is not in the label list.")
+                continue
+            i = self.IDS.index(id)
+            center = np.mean(np.argwhere(label_ori == id), axis=0).astype(np.int64)
+            label[label_ori == id] = self.COLORS[id]
+            if print_label:
+                # add text in the center of the mask
+                cv2.putText(label, self.LABELS[i], (center[1], center[0]), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+                # print(i, self.LABELS[i])
+        # RGB to BGR
+        label = cv2.cvtColor(label, cv2.COLOR_RGB2BGR)
+        return cv2.addWeighted(label, 0.6, image_ori, 0.4, 0) if not only_label else label
+
+    def find_gt_labels(self, label_gt):
+        label_gt = label_gt.astype(np.uint8)
+        label_gt_list = []
+        for id in np.unique(label_gt):
+            if id == 0 or id == 255:
+                continue
+            elif id not in self.IDS:
+                print(f"Label {id} is not in the label list.")
+                continue
+            i = self.IDS.index(id)
+            label_gt_list.append(self.LABELS[i])
+        return label_gt_list
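For datasets other than COCO and Cityscapes, `Labels` falls back to the standard Pascal VOC palette generated by the bit manipulation in `create_pascal_label_colormap`. A small sketch of what it produces:

from utils.labels_utils import create_pascal_label_colormap

cmap = create_pascal_label_colormap()
# index 0 is background, 1 is (128, 0, 0), 2 is (0, 128, 0), 15 ("person" in VOC) is (192, 128, 128)
print(cmap[0], cmap[1], cmap[2], cmap[15])  # [0 0 0] [128 0 0] [0 128 0] [192 128 128]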
utils/llms_utils.py
ADDED
@@ -0,0 +1,233 @@
|
1 |
+
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from openai import OpenAI
|
5 |
+
from termcolor import colored
|
6 |
+
|
7 |
+
import transformers
|
8 |
+
# from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
|
9 |
+
from huggingface_hub import login
|
10 |
+
|
11 |
+
# environment variables and paths
|
12 |
+
from .env_utils import get_device, low_vram_mode
|
13 |
+
|
14 |
+
device = get_device()
|
15 |
+
|
16 |
+
class GPT:
|
17 |
+
def __init__(self, model="gpt-4o-mini", api_key=None):
|
18 |
+
self.prices = {
|
19 |
+
# check at https://openai.com/api/pricing/
|
20 |
+
"gpt-3.5-turbo-0125": [0.0000005, 0.0000015],
|
21 |
+
"gpt-4o-mini" : [0.00000015, 0.00000060],
|
22 |
+
"gpt-4-1106-preview": [0.00001, 0.00003],
|
23 |
+
"gpt-4-0125-preview": [0.00001, 0.00003],
|
24 |
+
"gpt-4-turbo" : [0.00001, 0.00003],
|
25 |
+
"gpt-4o" : [0.000005, 0.000015],
|
26 |
+
}
|
27 |
+
self.cheaper_model = "gpt-4o-mini"
|
28 |
+
assert model in self.prices, "Invalid model, please choose from: {}, or add new models in the code.".format(list(self.prices.keys()))
|
29 |
+
self.model = model
|
30 |
+
print(f"Using {model}")
|
31 |
+
self.client = OpenAI(api_key=api_key)
|
32 |
+
self.total_cost = 0.0
|
33 |
+
|
34 |
+
def _update(self, response, price):
|
35 |
+
current_cost = response.usage.prompt_tokens * price[0] + response.usage.completion_tokens * price[1]  # price = [input $/token, output $/token]
|
36 |
+
self.total_cost += current_cost
|
37 |
+
# print in 4 decimal places
|
38 |
+
print(
|
39 |
+
colored(
|
40 |
+
f"Current Tokens: {response.usage.completion_tokens + response.usage.prompt_tokens:d} \
|
41 |
+
Current cost: {current_cost:.4f} $, \
|
42 |
+
Total cost: {self.total_cost:.4f} $",
|
43 |
+
"yellow",
|
44 |
+
)
|
45 |
+
)
|
46 |
+
|
47 |
+
def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
|
48 |
+
# set temperature to 0.0 for more deterministic results
|
49 |
+
if post:
|
50 |
+
# use cheaper model for post-refinement to save costs, since the task is simpler.
|
51 |
+
generated_text = self.client.chat.completions.create(
|
52 |
+
model=self.cheaper_model, messages=messages, temperature=temperature, max_tokens=max_tokens
|
53 |
+
)
|
54 |
+
self._update(generated_text, self.prices[self.cheaper_model])
|
55 |
+
else:
|
56 |
+
generated_text = self.client.chat.completions.create(
|
57 |
+
model=self.model, messages=messages, temperature=temperature, max_tokens=max_tokens
|
58 |
+
)
|
59 |
+
self._update(generated_text, self.prices[self.model])
|
60 |
+
generated_text = generated_text.choices[0].message.content
|
61 |
+
return generated_text
|
62 |
+
|
63 |
+
|
64 |
+
class Llama3:
|
65 |
+
def __init__(self, model="Meta-Llama-3-8B-Instruct"):
|
66 |
+
login(token=os.getenv('HF_TOKEN'))
|
67 |
+
model = "meta-llama/{}".format(model) # or replace with your local model path
|
68 |
+
print(f"Using {model}")
|
69 |
+
# ZeroGPU does not support quantization.
|
70 |
+
# tokenizer = AutoTokenizer.from_pretrained(model)
|
71 |
+
# if low_vram_mode:
|
72 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
73 |
+
# model, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
|
74 |
+
# ).eval()
|
75 |
+
self.pipeline = transformers.pipeline(
|
76 |
+
"text-generation",
|
77 |
+
model = model,
|
78 |
+
# tokenizer = tokenizer,
|
79 |
+
model_kwargs = {"torch_dtype": torch.bfloat16},
|
80 |
+
device_map = "auto",
|
81 |
+
)
|
82 |
+
self.terminators = [self.pipeline.tokenizer.eos_token_id, self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
|
83 |
+
|
84 |
+
def _update(self):
|
85 |
+
print(colored("Using Llama-3, Free", "green"))
|
86 |
+
|
87 |
+
def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
|
88 |
+
prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
89 |
+
generated_text = self.pipeline(
|
90 |
+
prompt,
|
91 |
+
max_new_tokens = max_tokens,
|
92 |
+
eos_token_id = self.terminators,
|
93 |
+
pad_token_id = 128001,
|
94 |
+
do_sample = True,
|
95 |
+
temperature = max(temperature, 0.01), # 0.0 is not supported
|
96 |
+
top_p = 0.9,
|
97 |
+
)
|
98 |
+
self._update()
|
99 |
+
generated_text = generated_text[0]["generated_text"][len(prompt) :]
|
100 |
+
return generated_text
|
101 |
+
|
102 |
+
|
103 |
+
# Define the timeout handler
|
104 |
+
def timeout_handler(signum, frame):
|
105 |
+
raise TimeoutError()
|
106 |
+
|
107 |
+
|
108 |
+
def init_model(model, api_key=None):
|
109 |
+
if "gpt" in model:
|
110 |
+
return GPT(model=model, api_key=api_key)
|
111 |
+
elif "Llama" in model:
|
112 |
+
return Llama3(model=model)
|
113 |
+
else:
|
114 |
+
raise ValueError("Invalid model")
|
115 |
+
|
116 |
+
|
117 |
+
def _generate_example_prompt(examples, llm=None):
|
118 |
+
# system prompt
|
119 |
+
system_prompt = """
|
120 |
+
Task Description:
|
121 |
+
- You will provide detailed explanations for example inputs and outputs within the context of the task.
|
122 |
+
|
123 |
+
Please adhere to the following rules:
|
124 |
+
- Exclude terms that appear in both lists.
|
125 |
+
- Detail the relevance of unmatched terms from input to output, focusing on indirect relationships.
|
126 |
+
- Identify and explain terms common to all output lists but rarely present in input lists; include these at the end of the output labeled 'Recommend Include Labels'.
|
127 |
+
- Each explanation should be concise, around 50 words.
|
128 |
+
|
129 |
+
Output Format:
|
130 |
+
- '1. Input... Output... Explanation... n. Input... Output... Explanation... \n Recommend Include Labels: label1, labeln, ...'
|
131 |
+
"""
|
132 |
+
messages = [
|
133 |
+
{"role": "system", "content": system_prompt},
|
134 |
+
{
|
135 |
+
"role": "user",
|
136 |
+
"content": f"Here are the input and output lists for which you need to provide detailed explanations:{examples.strip()}",
|
137 |
+
},
|
138 |
+
]
|
139 |
+
generated_example = llm.chat(messages, temperature=0.0, max_tokens=1000)
|
140 |
+
return generated_example
|
141 |
+
|
142 |
+
|
143 |
+
def _make_prompt(label_list, example=None):
|
144 |
+
Cityscape = "sidewalk" in label_list
|
145 |
+
if Cityscape:
|
146 |
+
add_text = f'contain at least {len(label_list.split(", "))} labels, '
|
147 |
+
else:
|
148 |
+
add_text = ""
|
149 |
+
# Task description and instructions for processing the input to generate output
|
150 |
+
system_prompt = f"""
|
151 |
+
Task Description:
|
152 |
+
- You will receive a list of caption tags accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".
|
153 |
+
|
154 |
+
Instructions:
|
155 |
+
Step 1. Visualize the scene suggested by the input caption tags and text.
|
156 |
+
Step 2. Analyze each term within the overall scene to predict relevant labels from the predefined list, ensuring no term is overlooked.
|
157 |
+
Step 3. Now forget the input list and focus on the scene as a whole, expanding upon the labels to include any contextually relevant labels that complete the scene or setting.
|
158 |
+
Step 4. Compile all identified labels into a comma-separated list, adhering strictly to the specified format.
|
159 |
+
|
160 |
+
Contextually Relevant Tips:
|
161 |
+
- Equivalencies include converting "girl, man" to "person" and "flower, vase" to "potted plant", while "bicycle, motorcycle" suggest "rider".
|
162 |
+
- An outdoor scene may include labels like "sky", "tree", "clouds", "terrain".
|
163 |
+
- An urban scene may imply "bus", "bicycle", "road", "sidewalk", "building", "pole", "traffic-light", "traffic-sign".
|
164 |
+
|
165 |
+
Output:
|
166 |
+
- Do not output any explanations other than the final label list.
|
167 |
+
- The final output should {add_text}strictly adhere to the specified format: label1, label2, ... labeln
|
168 |
+
""".strip()
|
169 |
+
if example:
|
170 |
+
system_prompt += f"""
|
171 |
+
Additional Examples with Detailed Explanations:
|
172 |
+
{example}
|
173 |
+
"""
|
174 |
+
print("system_prompt: ", system_prompt)
|
175 |
+
return system_prompt
|
176 |
+
|
177 |
+
# - You will receive a list of terms accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".
|
178 |
+
|
179 |
+
# Instructions:
|
180 |
+
# Step 1. Visualize the scene suggested by the input list and caption text.
|
181 |
+
|
182 |
+
|
183 |
+
def make_prompt(label_list):
|
184 |
+
# Create a new system prompt using the label list and the improved example prompt
|
185 |
+
system_prompt = _make_prompt(label_list)
|
186 |
+
system_prompt = {"role": "system", "content": system_prompt.strip()}
|
187 |
+
print("system_prompt: ", system_prompt)
|
188 |
+
return system_prompt
|
189 |
+
|
190 |
+
|
191 |
+
def _call_llm(system_prompt, llm, user_input):
|
192 |
+
messages = [system_prompt, {"role": "user", "content": "Here are input caption tags and text: " + user_input}]
|
193 |
+
converted_label = llm.chat(messages=messages, temperature=0.0, max_tokens=200)
|
194 |
+
return converted_label
|
195 |
+
|
196 |
+
|
197 |
+
def pre_refinement(user_input_list, system_prompt, llm=None):
|
198 |
+
llm_outputs = [_call_llm(system_prompt, llm, user_input) for user_input in user_input_list]
|
199 |
+
converted_labels = [f"{user_input_}, {converted_label}" for user_input_, converted_label in zip(user_input_list, llm_outputs)]
|
200 |
+
return converted_labels, llm_outputs
|
201 |
+
|
202 |
+
|
203 |
+
def post_refinement(label_list, detected_label, llm=None):
|
204 |
+
system_input = f"""
|
205 |
+
Task Description:
|
206 |
+
- You will receive a specific phrase and must assign an appropriate label from the predefined label list: "{label_list}". \n \
|
207 |
+
|
208 |
+
Please adhere to the following rules: \n \
|
209 |
+
- Select and return only one relevant label from the predefined label list that corresponds to the given phrase. \n \
|
210 |
+
- Do not include any additional information or context beyond the label itself. \n \
|
211 |
+
- Format is purely the label itself, without any additional punctuation or formatting. \n \
|
212 |
+
"""
|
213 |
+
system_input = {"role": "system", "content": system_input}
|
214 |
+
messages = [system_input, {"role": "user", "content": detected_label}]
|
215 |
+
if detected_label == "":
|
216 |
+
return ""
|
217 |
+
generated_label = None
|
218 |
+
for count in range(3):
|
219 |
+
generated_label = llm.chat(messages=messages, temperature=0.0 if count == 0 else 0.1 * (count), post=True)
|
220 |
+
if generated_label != "":
|
221 |
+
break
|
222 |
+
return generated_label
|
223 |
+
|
224 |
+
|
225 |
+
if __name__ == "__main__":
|
226 |
+
# test the functions
|
227 |
+
llm = Llama3(model="Meta-Llama-3-8B-Instruct")
|
228 |
+
|
229 |
+
system_prompt = make_prompt("person, car, tree, sky, road, building, sidewalk, traffic-light, traffic-sign")  # make_prompt() takes only the label list
|
230 |
+
|
231 |
+
converted_labels, llm_outputs = pre_refinement(["person, car, road, traffic-light"], system_prompt, llm=llm)
|
232 |
+
print("converted_labels: ", converted_labels)
|
233 |
+
print("llm_outputs: ", llm_outputs)
|
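For orientation, here is a minimal sketch of how these helpers chain together (assuming the module is importable as utils.llms_utils, that OPENAI_API_KEY is set, and using a made-up tag string; this is not a verbatim excerpt from app.py):

import os
from utils.llms_utils import init_model, make_prompt, pre_refinement, post_refinement

llm = init_model("gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY"))
label_list = "person, car, road, sky, building"           # predefined label list (example)
system_prompt = make_prompt(label_list)

# One image's caption tags plus caption text (illustrative input)
tags = ["man, car, street, traffic light. A man crosses a busy street."]
converted_labels, raw_outputs = pre_refinement(tags, system_prompt, llm=llm)
print(converted_labels)

# Snap a single free-form detection phrase back onto the label list
print(post_refinement(label_list, "sports car", llm=llm))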
utils/ram_utils.py
ADDED
@@ -0,0 +1,86 @@
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from .env_utils import get_device, low_vram_mode
|
5 |
+
|
6 |
+
device = get_device()
|
7 |
+
|
8 |
+
pretrained_models_folder = os.path.join(os.path.dirname(__file__), "../pretrained-models")
|
9 |
+
|
10 |
+
|
11 |
+
# RAM++
|
12 |
+
from ram.models import ram_plus
|
13 |
+
from ram import get_transform, inference_ram
|
14 |
+
|
15 |
+
ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
|
16 |
+
ram_precision = torch.bfloat16
|
17 |
+
|
18 |
+
|
19 |
+
def ram_init():
|
20 |
+
image_size = 384
|
21 |
+
transform = get_transform(image_size=image_size)
|
22 |
+
# Load the RAM++ model
|
23 |
+
model = ram_plus(pretrained=ram_ckpt, image_size=image_size, vit="swin_l")
|
24 |
+
model = model.to(device=device, dtype=ram_precision)
|
25 |
+
model.eval()
|
26 |
+
print("RAM++ model loaded")
|
27 |
+
return model, transform
|
28 |
+
|
29 |
+
|
30 |
+
# Initialize the model when importing the module
|
31 |
+
ram_model, ram_transform = ram_init()
|
32 |
+
|
33 |
+
|
34 |
+
def _inference(image_pil):
|
35 |
+
image = ram_transform(image_pil).unsqueeze(0)
|
36 |
+
image = image.to(device=device, dtype=ram_precision)
|
37 |
+
res = inference_ram(image, ram_model)
|
38 |
+
result = res[0].replace(" | ", ", ")
|
39 |
+
return result
|
40 |
+
|
41 |
+
|
42 |
+
def _split_large_image(image_pil):
|
43 |
+
size = image_pil.size
|
44 |
+
print("Image size is too large, split into smaller patches")
|
45 |
+
# Split the image into 4 patches
|
46 |
+
patches = []
|
47 |
+
patch_size = (size[0] // 2, size[1] // 2)
|
48 |
+
for i in range(2):
|
49 |
+
for j in range(2):
|
50 |
+
left = i * patch_size[0]
|
51 |
+
top = j * patch_size[1]
|
52 |
+
right = left + patch_size[0]
|
53 |
+
bottom = top + patch_size[1]
|
54 |
+
patch = image_pil.crop((left, top, right, bottom))
|
55 |
+
patches.append(patch)
|
56 |
+
return patches
|
57 |
+
|
58 |
+
|
59 |
+
def ram_inference(image_pil: Image.Image):
|
60 |
+
size = image_pil.size
|
61 |
+
if size[0] > 640 or size[1] > 640:
|
62 |
+
# split only once in the online demo version.
|
63 |
+
patches = _split_large_image(image_pil)
|
64 |
+
# while any(patch.size[0] > 640 or patch.size[1] > 640 for patch in patches):
|
65 |
+
# patches = [_split_large_image(patch) for patch in patches]
|
66 |
+
# patches = [patch for sublist in patches for patch in sublist]
|
67 |
+
# Inference on each patch
|
68 |
+
results = []
|
69 |
+
for patch in patches:
|
70 |
+
result = _inference(patch)
|
71 |
+
results.extend(result.split(", "))
|
72 |
+
results = list(set(results))
|
73 |
+
# Combine the results
|
74 |
+
final_result = ", ".join(results)
|
75 |
+
return final_result
|
76 |
+
else:
|
77 |
+
print("Image size is small enough for inference")
|
78 |
+
return _inference(image_pil)
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
# Test the RAM++ model
|
83 |
+
image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
|
84 |
+
image = Image.open(image_path)
|
85 |
+
result = ram_inference(image)
|
86 |
+
print(result)
|
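A quick way to exercise this module is sketched below; note that the import loads the RAM++ checkpoint immediately, and the example image path is assumed to be one of the bundled examples:

from PIL import Image
from utils.ram_utils import ram_inference   # note: loads RAM++ at import time

image = Image.open("examples/VOC2012_eg.jpg")   # any RGB image; path assumed
tags = ram_inference(image)                     # comma-separated tag string
print(tags)                                     # e.g. "person, car, road, ..."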
utils/timer_utils.py
ADDED
@@ -0,0 +1,89 @@
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
|
5 |
+
|
6 |
+
def create_logger(logger_name: str, log_file_path: os.PathLike = None):
|
7 |
+
"""
|
8 |
+
Create a logger with the specified name and log file path.
|
9 |
+
"""
|
10 |
+
logger = logging.getLogger(logger_name)
|
11 |
+
logger.propagate = False
|
12 |
+
logger.setLevel(logging.DEBUG)
|
13 |
+
assert log_file_path is not None, "log_file_path is required"
|
14 |
+
fh = logging.FileHandler(log_file_path)
|
15 |
+
fh_formatter = logging.Formatter("%(asctime)s : %(levelname)s, %(funcName)s Message: %(message)s")
|
16 |
+
fh.setFormatter(fh_formatter)
|
17 |
+
logger.addHandler(fh)
|
18 |
+
logger.info(f"logging start: {logger_name}")
|
19 |
+
return logger
|
20 |
+
|
21 |
+
|
22 |
+
class Timer:
|
23 |
+
"""
|
24 |
+
A simple timer class for measuring elapsed time.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, filename: os.PathLike = "timer_log.log", reset: bool = False):
|
28 |
+
"""
|
29 |
+
Initialize the Timer object.
|
30 |
+
"""
|
31 |
+
self.start_time = None
|
32 |
+
self.last_checkpoint = None
|
33 |
+
self.filename = filename
|
34 |
+
self.logger = create_logger("Timer", filename)
|
35 |
+
if reset:
|
36 |
+
self._reset_log_file()
|
37 |
+
|
38 |
+
def _reset_log_file(self):
|
39 |
+
"""
|
40 |
+
Reset the log file by clearing its contents.
|
41 |
+
"""
|
42 |
+
with open(self.filename, "w") as file:
|
43 |
+
file.write("")
|
44 |
+
|
45 |
+
def start(self):
|
46 |
+
"""
|
47 |
+
Start the timer.
|
48 |
+
"""
|
49 |
+
self.start_time = time.time()
|
50 |
+
self.last_checkpoint = self.start_time
|
51 |
+
self.logger.info("Timer started.")
|
52 |
+
|
53 |
+
def check(self, message):
|
54 |
+
"""
|
55 |
+
Log a checkpoint with the current time and time since the last checkpoint.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
message (str): The message to include in the log.
|
59 |
+
"""
|
60 |
+
if self.start_time is None:
|
61 |
+
self.logger.warning("Timer has not been started.")
|
62 |
+
else:
|
63 |
+
log_message = (
|
64 |
+
f"Current time count: {time.time() - self.start_time:.4f} seconds, "
|
65 |
+
f"Time since last checkpoint: {time.time() - self.last_checkpoint:.4f} seconds, "
|
66 |
+
f"for {message}"
|
67 |
+
)
|
68 |
+
self.last_checkpoint = time.time()
|
69 |
+
self.logger.info(log_message)
|
70 |
+
|
71 |
+
def stop(self):
|
72 |
+
"""
|
73 |
+
Stop the timer and log the elapsed time.
|
74 |
+
"""
|
75 |
+
if self.start_time is None:
|
76 |
+
self.logger.warning("Timer has not been started.")
|
77 |
+
else:
|
78 |
+
self.end_time = time.time()
|
79 |
+
self.logger.info(f"Total elapsed time: {self.end_time - self.start_time} seconds\n")
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
# Test the Timer class
|
84 |
+
timer = Timer(filename="timer_log.log", reset=True)
|
85 |
+
timer.start()
|
86 |
+
timer.check("First checkpoint")
|
87 |
+
time.sleep(1)
|
88 |
+
timer.check("Second checkpoint")
|
89 |
+
timer.stop()
|
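Usage mirrors the __main__ block above; a small self-contained sketch of bracketing pipeline stages with checkpoints (the stage names and sleeps are placeholders for the real tagging and refinement steps):

import time
from utils.timer_utils import Timer

timer = Timer(filename="timer_log.log", reset=True)
timer.start()
time.sleep(0.3)                    # stand-in for the tagging stage
timer.check("RAM++ tagging")
time.sleep(0.1)                    # stand-in for the LLM refinement stage
timer.check("LLM refinement")
timer.stop()                       # logs the total elapsed time to timer_log.log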