sky24h committed a9d25c7 (0 parents)

gradio demo for ZeroGPU, HF
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,171 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+
30
+ *.pth
31
+ *.bin
32
+ *.log
33
+ *.safetensors
34
+ outputs/
35
+ outputs_single/
36
+ results/
37
+ pretrained-models/checkpoints/
38
+
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/#use-with-ide
121
+ .pdm.toml
122
+
123
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124
+ __pypackages__/
125
+
126
+ # Celery stuff
127
+ celerybeat-schedule
128
+ celerybeat.pid
129
+
130
+ # SageMath parsed files
131
+ *.sage.py
132
+
133
+ # Environments
134
+ .env
135
+ .venv
136
+ env/
137
+ venv/
138
+ ENV/
139
+ env.bak/
140
+ venv.bak/
141
+
142
+ # Spyder project settings
143
+ .spyderproject
144
+ .spyproject
145
+
146
+ # Rope project settings
147
+ .ropeproject
148
+
149
+ # mkdocs documentation
150
+ /site
151
+
152
+ # mypy
153
+ .mypy_cache/
154
+ .dmypy.json
155
+ dmypy.json
156
+
157
+ # Pyre type checker
158
+ .pyre/
159
+
160
+ # pytype static type analyzer
161
+ .pytype/
162
+
163
+ # Cython debug symbols
164
+ cython_debug/
165
+
166
+ # PyCharm
167
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
170
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171
+ #.idea/
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: Training-Free Zero-Shot Semantic Segmentation With LLM Refinement
3
+ emoji: ⚡
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.38.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: agpl-3.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ import cv2
3
+ import spaces
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from omegaconf import OmegaConf
7
+
8
+ # set up environment
9
+ from utils.env_utils import set_random_seed, use_lower_vram
10
+ from utils.timer_utils import Timer
11
+
12
+ set_random_seed(1024)
13
+ timer = Timer()
14
+ timer.start()
15
+ # use_lower_vram()
16
+
17
+ # import functions
18
+ from utils.labels_utils import Labels
19
+ from utils.ram_utils import ram_inference
20
+ from utils.blip2_utils import blip2_caption
21
+ from utils.llms_utils import pre_refinement, make_prompt, init_model
22
+ from utils.grounded_sam_utils import run_grounded_sam
23
+
24
+
25
+ # hardcode parameters for G-SAM
26
+ box_threshold = 0.18
27
+ text_threshold = 0.15
28
+ iou_threshold = 0.8
29
+
30
+ global current_config, L, llm, system_prompt
31
+
32
+ # load Llama-3 here to avoid loading it during the inference.
33
+ llm = init_model("Meta-Llama-3-8B-Instruct")
34
+ current_config = ""
35
+ L = None
36
+ system_prompt = None
37
+
38
+ def load_config(config_type):
39
+ config = OmegaConf.load(os.path.join(os.path.dirname(__file__), f"configs/{config_type}.yaml"))
40
+ L = Labels(config=config)
41
+ # Init labels and the LLM system prompt. Only Meta-Llama-3-8B-Instruct is supported in this online demo, but any supported model can be used in a local environment with the released code.
42
+ system_prompt = make_prompt(", ".join(L.LABELS))
43
+ return L, system_prompt
44
+
45
+ @spaces.GPU(duration=120)
46
+ def process(image_ori, config_type):
47
+ global current_config, L, llm, system_prompt
48
+ if current_config != config_type:
49
+ L, system_prompt = load_config(config_type)
50
+ current_config = config_type
51
+ else:
52
+ pass
53
+ image_ori = cv2.cvtColor(image_ori, cv2.COLOR_BGR2RGB)
54
+ image_pil = Image.fromarray(image_ori)
55
+ labels_ram = ram_inference(image_pil) + ": " + blip2_caption(image_pil)
56
+ converted_labels, llm_output = pre_refinement([labels_ram], system_prompt, llm=llm)
57
+ labels_llm = L.check_labels(converted_labels)[0]
58
+ print("labels_ram: ", labels_ram)
59
+ print("llm_output: ", llm_output)
60
+ print("labels_llm: ", labels_llm)
61
+
62
+ # run sam
63
+ label_res, bboxes, output_labels, output_prob_maps, output_points = run_grounded_sam(
64
+ input_image = {"image": image_pil, "mask": None},
65
+ text_prompt = labels_llm,
66
+ box_threshold = box_threshold,
67
+ text_threshold = text_threshold,
68
+ iou_threshold = iou_threshold,
69
+ LABELS = L.LABELS,
70
+ IDS = L.IDS,
71
+ llm = llm,
72
+ timer = timer,
73
+ )
74
+
75
+ # draw mask and save image
76
+ ours = L.draw_mask(label_res, image_ori, print_label=True, tag="Ours")
77
+ return cv2.cvtColor(ours, cv2.COLOR_BGR2RGB)
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # options for different settings
82
+ dropdown_options = ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]
83
+ default_option = "COCO-81"
84
+
85
+ with gr.Blocks() as demo:
86
+ gr.HTML(
87
+ """
88
+ <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
89
+ Training-Free Zero-Shot Semantic Segmentation with LLM Refinement
90
+ </h1>
91
+ <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
92
+ <a style="text-align: center; display:inline-block"
93
+ href="https://sky24h.github.io/websites/bmvc2024_training-free-semseg-with-LLM/">
94
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
95
+ alt="Paper Page">
96
+ </a>
97
+ <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Training-Free_Zero-Shot_Semantic_Segmentation_with_LLM_Refinement?duplicate=true">
98
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
99
+ </a>
100
+ </p>
101
+ """
102
+ )
103
+ gr.Interface(
104
+ fn=process,
105
+ inputs=[gr.Image(type="numpy", height="384"), gr.Dropdown(choices=dropdown_options, label="Refinement Type", value=default_option)],
106
+ outputs="image",
107
+ description="""<html>
108
+ <p style="text-align:center;"> This is an online demo for the paper "Training-Free Zero-Shot Semantic Segmentation with LLM Refinement" (BMVC 2024). </p>
109
+ <p style="text-align:center;"> Usage: Please select or upload an image and choose a dataset setting for semantic segmentation refinement.</p>
110
+ </html>""",
111
+ allow_flagging='never',
112
+ examples=[
113
+ ["examples/Cityscapes_eg.jpg", "Cityscapes"],
114
+ ["examples/DRAM_eg.jpg", "DRAM"],
115
+ ["examples/COCO-81_eg.jpg", "COCO-81"],
116
+ ["examples/VOC2012_eg.jpg", "VOC2012"],
117
+ ],
118
+ cache_examples=True,
119
+ )
120
+
121
+ demo.queue(max_size=10).launch()
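
app.py wires the full pipeline together: RAM++ tagging plus a BLIP-2 caption feed the LLM pre-refinement, the refined labels drive Grounded-SAM, and the resulting mask is blended over the input. The same process() entry point can also be driven without the Gradio UI. The following is only a minimal sketch, not part of the commit; it assumes the checkpoints under pretrained-models/checkpoints are present, a CUDA GPU is available, the spaces package is installed, and HF_TOKEN grants access to Meta-Llama-3-8B-Instruct (the script name and output path are illustrative):

# hypothetical local driver, not part of this commit
import cv2
from app import process  # importing app loads RAM++, BLIP-2, Llama-3 and Grounded-SAM at module level

image_bgr = cv2.imread("examples/COCO-81_eg.jpg")   # cv2 reads BGR, which process() converts to RGB
result_rgb = process(image_bgr, "COCO-81")          # returns the blended segmentation as an RGB array
cv2.imwrite("result.png", cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))
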
configs/COCO-81.yaml ADDED
@@ -0,0 +1,3 @@
1
+ Name: COCO-81
2
+ label_list: "unlabeled, person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic-light, fire-hydrant, stop-sign, parking-meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports-ball, kite, baseball-bat, baseball-glove, skateboard, surfboard, tennis-racket, bottle, wine-glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot-dog, pizza, donut, cake, chair, couch, potted-plant, bed, dining-table, toilet, tv, laptop, mouse, remote, keyboard, cell-phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy-bear, hair-drier, toothbrush"
3
+ mask_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
configs/Cityscapes.yaml ADDED
@@ -0,0 +1,3 @@
1
+ Name: Cityscapes
2
+ label_list: "background, road, sidewalk, building, wall, fence, pole, traffic-light, traffic-sign, tree, terrain, sky, person, rider, car, truck, bus, train, motorcycle, bicycle"
3
+ mask_ids: [0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
configs/DRAM.yaml ADDED
@@ -0,0 +1,3 @@
1
+ Name: DRAM
2
+ label_list: "background, bird, boat, bottle, cat, chair, cow, dog, horse, person, potted-plant, sheep"
3
+ mask_ids: [0, 3, 4, 5, 8, 9, 10, 12, 13, 15, 16, 17]
configs/VOC2012.yaml ADDED
@@ -0,0 +1,2 @@
1
+ Name: VOC2012
2
+ label_list: "background, aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, dining-table, dog, horse, motorbike, person, potted-plant, sheep, sofa, train, monitor, void"
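
The four config files above share one small schema: Name, a comma-separated label_list, and an optional mask_ids list mapping each label to the dataset's mask index (VOC2012 omits mask_ids, so utils/labels_utils.py falls back to consecutive indices 0..N-1). A minimal sketch of how app.py consumes such a config, shown here for Cityscapes:

from omegaconf import OmegaConf
from utils.labels_utils import Labels
from utils.llms_utils import make_prompt

config = OmegaConf.load("configs/Cityscapes.yaml")
L = Labels(config=config)                          # splits label_list, resolves mask_ids, builds a colormap
system_prompt = make_prompt(", ".join(L.LABELS))   # system prompt constraining the LLM to these labels
print(L.LABELS[:3], list(L.IDS[:3]))               # -> ['background', 'road', 'sidewalk'] [0, 7, 8]
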
examples/COCO-81_eg.jpg ADDED
examples/Cityscapes_eg.jpg ADDED
examples/DRAM_eg.jpg ADDED
examples/VOC2012_eg.jpg ADDED
gradio_cached_examples/16/log.csv ADDED
@@ -0,0 +1,5 @@
1
+ output,flag,username,timestamp
2
+ "{""path"": ""gradio_cached_examples/16/output/e9085a590715dd9a4cbc/image.webp"", ""url"": ""/file=/tmp/gradio/f17a9230acfa1f7c9d09b85c0c0528e64c5a19ec/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:26.975686
3
+ "{""path"": ""gradio_cached_examples/16/output/ca6408b417ed4de51f74/image.webp"", ""url"": ""/file=/tmp/gradio/28f694172a8e086c7d12474e78b1e36453357589/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:28.056743
4
+ "{""path"": ""gradio_cached_examples/16/output/6b3896574851c5665d17/image.webp"", ""url"": ""/file=/tmp/gradio/f2a070d0cd932cc4bf9ddb7f4cca22c01d4d4e37/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:30.868722
5
+ "{""path"": ""gradio_cached_examples/16/output/8d66e32b3b15feb7ecc9/image.webp"", ""url"": ""/file=/tmp/gradio/9318730ca69938781665675ccbe76d635bc47a2d/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",,,2024-07-30 06:54:32.230374
gradio_cached_examples/16/output/6b3896574851c5665d17/image.webp ADDED
gradio_cached_examples/16/output/8d66e32b3b15feb7ecc9/image.webp ADDED
gradio_cached_examples/16/output/ca6408b417ed4de51f74/image.webp ADDED
gradio_cached_examples/16/output/e9085a590715dd9a4cbc/image.webp ADDED
pre-requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ # transformers and timm are installed in a second stage (requirements.txt) to avoid an installation error with GroundingDINO
2
+ torch==2.3.1 #pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
3
+ torchvision==0.18.1
4
+ setuptools==69.5.1
5
+ gradio==4.38.1
6
+ openai>=1.0.0
7
+ opencv_python==4.8.1.78
8
+ diffusers[torch]==0.29.2
9
+ termcolor
10
+ fairscale
11
+ natsort
12
+ omegaconf
13
+ pycocotools
14
+ matplotlib
15
+ onnxruntime
16
+ onnx
17
+ groundingdino-py
18
+ segment_anything@git+https://github.com/SysCV/sam-hq.git
19
+ ram@git+https://github.com/xinyu1205/recognize-anything.git
pretrained-models/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py ADDED
@@ -0,0 +1,43 @@
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_T_224_1k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
pretrained-models/checkpoints/groundingdino_swint_ogc.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799
3
+ size 693997677
pretrained-models/checkpoints/put pre-trained checkpoints here.txt ADDED
File without changes
pretrained-models/checkpoints/ram_plus_swin_large_14m.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:497c178836ba66698ca226c7895317e6e800034be986452dbd2593298d50e87d
3
+ size 3010210801
pretrained-models/checkpoints/sam_hq_vit_l.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1a6c385d62bf005ded91a54d5ec55c985cfc4103ef89c08d90f39f04934c343
3
+ size 1254865805
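
The three .pth files above are committed as Git LFS pointers (version, sha256 oid, size) rather than raw weights. When reproducing the Space locally, the same checkpoints can also be pulled from their upstream repositories, whose ids appear in the commented-out downloader in utils/env_utils.py. A minimal sketch using huggingface_hub, assuming the filenames sit at the repo roots as in those URLs:

from huggingface_hub import hf_hub_download

ckpt_dir = "pretrained-models/checkpoints"
hf_hub_download("ShilongLiu/GroundingDINO", "groundingdino_swint_ogc.pth", local_dir=ckpt_dir)
hf_hub_download("lkeab/hq-sam", "sam_hq_vit_l.pth", local_dir=ckpt_dir)
hf_hub_download("xinyu1205/recognize-anything-plus-model", "ram_plus_swin_large_14m.pth", local_dir=ckpt_dir)
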
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ # Somehow, we needed to put transformers after GroundingDINO.
2
+ transformers==4.42.4
3
+ timm==1.0.8
utils/Arial.ttf ADDED
Binary file (276 kB)
utils/blip2_utils.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ import torch
3
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration # , BitsAndBytesConfig
4
+ from .env_utils import get_device, low_vram_mode
5
+
6
+ device = get_device()
7
+
8
+ blip2_model_id = "Salesforce/blip2-opt-2.7b" # or replace with your local model path
9
+ blip2_precision = torch.bfloat16
10
+
11
+ # Load BLIP2 model and processor from HuggingFace
12
+ blip2_processor = Blip2Processor.from_pretrained(blip2_model_id)
13
+ if low_vram_mode:
14
+ blip2_model = Blip2ForConditionalGeneration.from_pretrained(
15
+ blip2_model_id,
16
+ torch_dtype=blip2_precision,
17
+ device_map=device,
18
+ # quantization_config = BitsAndBytesConfig(load_in_8bit=True) if low_vram_mode else None, # ZeroGPU does not support quantization.
19
+ ).eval()
20
+ else:
21
+ blip2_model = Blip2ForConditionalGeneration.from_pretrained(blip2_model_id, torch_dtype=blip2_precision, device_map=device).eval()
22
+
23
+
24
+ def blip2_caption(raw_image):
25
+ # unconditional image captioning
26
+ inputs = blip2_processor(raw_image, return_tensors="pt")
27
+ inputs = inputs.to(device=device, dtype=blip2_precision)
28
+ out = blip2_model.generate(**inputs)
29
+ caption = blip2_processor.decode(out[0], skip_special_tokens=True)
30
+ return caption
31
+
32
+
33
+ # if __name__ == "__main__":
34
+ # from PIL import Image
35
+
36
+ # # Test the RAM++ model
37
+ # image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
38
+ # image = Image.open(image_path)
39
+ # result = blip2_caption(image)
40
+ # print(result)
utils/env_utils.py ADDED
@@ -0,0 +1,56 @@
1
+ # Avoid multiple imports of the same module. Use this to import the module only once.
2
+ # Also, ensure that the device and pretrained models folder are consistent across the project.
3
+
4
+ import os
5
+ import torch
6
+
7
+ global low_vram_mode
8
+ low_vram_mode = False
9
+
10
+
11
+ def use_lower_vram():
12
+ global low_vram_mode
13
+ low_vram_mode = True
14
+
15
+
16
+ def get_device():
17
+ device = torch.device("cuda") # must use GPU in online demo version
18
+ return device
19
+
20
+
21
+ def set_random_seed(seed: int):
22
+ torch.manual_seed(seed)
23
+ torch.cuda.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ torch.backends.cudnn.deterministic = True
26
+ torch.backends.cudnn.benchmark = False
27
+
28
+
29
+ def get_pretrained_models_folder():
30
+ return os.path.join(os.path.dirname(__file__), "../pretrained-models")
31
+
32
+
33
+ # def download_pretrained_models():
34
+ # pretrained_models_folder = get_pretrained_models_folder()
35
+ # # hard-coded download links
36
+ # groundingdino_link = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
37
+ # sam_link = "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
38
+ # ram_link = "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
39
+ # groundingdino_ckpt = os.path.join(pretrained_models_folder, "checkpoints/groundingdino_swint_ogc.pth")
40
+ # sam_ckpt = os.path.join(pretrained_models_folder, "checkpoints/sam_hq_vit_l.pth")
41
+ # ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
42
+
43
+ # # download pretrained models if not exists
44
+ # if not os.path.exists(groundingdino_ckpt):
45
+ # print(f"Downloading pretrained model: {groundingdino_ckpt}")
46
+ # os.system(f"wget -O {groundingdino_ckpt} {groundingdino_link} -q")
47
+ # if not os.path.exists(sam_ckpt):
48
+ # print(f"Downloading pretrained model: {sam_ckpt}")
49
+ # os.system(f"wget -O {sam_ckpt} {sam_link} -q")
50
+ # if not os.path.exists(ram_ckpt):
51
+ # print(f"Downloading pretrained model: {ram_ckpt}")
52
+ # os.system(f"wget -O {ram_ckpt} {ram_link} -q")
53
+
54
+
55
+ # # download pretrained models when imported
56
+ # download_pretrained_models()
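
One detail worth noting: blip2_utils.py, llms_utils.py and grounded_sam_utils.py import low_vram_mode by value (from .env_utils import ... low_vram_mode), so use_lower_vram() only takes effect if it is called before those modules are first imported. This is why app.py sets the seed (and, optionally, the low-VRAM flag) ahead of the model imports. A minimal sketch of the intended ordering, not part of the commit:

from utils.env_utils import set_random_seed, use_lower_vram

set_random_seed(1024)
use_lower_vram()        # must run before the model modules capture low_vram_mode

# only now import the modules that read the flag at import time
from utils.blip2_utils import blip2_caption
from utils.llms_utils import init_model
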
utils/grounded_sam_utils.py ADDED
@@ -0,0 +1,348 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import torchvision
5
+ import numpy as np
6
+ from PIL import Image, ImageFont
7
+ import traceback
8
+
9
+ # environment variables and paths
10
+ from .env_utils import get_device, get_pretrained_models_folder, low_vram_mode
11
+
12
+ device = get_device()
13
+ pretrained_models_folder = get_pretrained_models_folder()
14
+ groundingdino_ckpt = os.path.join(pretrained_models_folder, "checkpoints/groundingdino_swint_ogc.pth")
15
+ sam_ckpt = os.path.join(pretrained_models_folder, "checkpoints/sam_hq_vit_l.pth")
16
+
17
+ # segment anything
18
+ from segment_anything import build_sam_vit_l, SamPredictor
19
+
20
+ # Grounding DINO
21
+ from groundingdino.models import build_model
22
+ from groundingdino.util.slconfig import SLConfig
23
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
24
+ import groundingdino.datasets.transforms as T
25
+
26
+ font_family = os.path.join(os.path.dirname(__file__), "Arial.ttf")
27
+ font_size = 24
28
+ font = ImageFont.truetype(font_family, font_size)
29
+
30
+ from .llms_utils import post_refinement
31
+
32
+
33
+ def draw_bboxes(ours_bboxes, output_labels, bboxes, output_points, output_prob_maps):
34
+ # draw bboxes on the image
35
+ for label, bbox in zip(output_labels, bboxes):
36
+ bbox = bbox.cpu().numpy()
37
+ bbox = [int(round(bbox[0])), int(round(bbox[1])), int(round(bbox[2])), int(round(bbox[3]))]
38
+ # print("label, bbox", label, bbox)
39
+ cv2.rectangle(ours_bboxes, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
40
+ # caption inside the bbox, below the top left corner 20 pixels
41
+ cv2.putText(ours_bboxes, label, (bbox[0], bbox[1] + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
42
+ try:
43
+ for points in output_points:
44
+ for point in points:
45
+ # draw a cross on the point
46
+ cv2.drawMarker(ours_bboxes, (int(point[0]), int(point[1])), (0, 0, 255), cv2.MARKER_CROSS, 10, 2)
47
+ except: # noqa
48
+ pass
49
+
50
+ # Draw the probability maps
51
+ # if output_prob_maps is not None:
52
+ # output_prob_maps = np.concatenate(output_prob_maps, axis=1)
53
+ # ours_bboxes = np.concatenate([output_prob_maps, ours_bboxes], axis=1)
54
+ return ours_bboxes
55
+
56
+
57
+ def transform_image(image_pil):
58
+ transform = T.Compose(
59
+ [
60
+ T.RandomResize([800], max_size=1333),
61
+ T.ToTensor(),
62
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
63
+ ]
64
+ )
65
+ image, _ = transform(image_pil, None) # 3, h, w
66
+ return image
67
+
68
+
69
+ def _load_model(model_config_path, model_checkpoint_path, device):
70
+ args = SLConfig.fromfile(model_config_path)
71
+ args.device = device
72
+ model = build_model(args)
73
+ model.load_state_dict(clean_state_dict(torch.load(model_checkpoint_path, map_location="cpu")["model"]), strict=False)
74
+ return model.to(device=device).eval()
75
+
76
+
77
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
78
+ caption = caption.lower()
79
+ caption = caption.strip()
80
+ if not caption.endswith("."):
81
+ caption = caption + "."
82
+
83
+ with torch.no_grad():
84
+ outputs = model(image[None], captions=[caption])
85
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
86
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
87
+ logits.shape[0]
88
+
89
+ # filter output
90
+ logits_filt = logits.clone()
91
+ boxes_filt = boxes.clone()
92
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
93
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
94
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
95
+ logits_filt.shape[0]
96
+
97
+ # get phrase
98
+ tokenlizer = model.tokenizer
99
+ tokenized = tokenlizer(caption)
100
+ # build pred
101
+ pred_phrases = []
102
+ scores = []
103
+ for logit, box in zip(logits_filt, boxes_filt):
104
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
105
+ if with_logits:
106
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
107
+ else:
108
+ pred_phrases.append(pred_phrase)
109
+ scores.append(logit.max().item())
110
+ return boxes_filt, torch.Tensor(scores), pred_phrases
111
+
112
+
113
+ def postprocess_masks(input_masks, input_pred_phrases):
114
+ input_masks_ = input_masks.cpu().numpy().transpose(0, 2, 3, 1).copy()
115
+ output_masks = input_masks.cpu().numpy().transpose(0, 2, 3, 1).copy()
116
+ for i in range(len(output_masks)):
117
+ for j in range(len(output_masks)):
118
+ if i == j:
119
+ continue
120
+ if ((input_masks_[i] * input_masks_[j]).sum() > 0) and (input_pred_phrases[i].split("(")[0] != input_pred_phrases[j].split("(")[0]):
121
+ # if two masks overlap and have different labels
122
+ if float(input_pred_phrases[i].split("(")[1].split(")")[0]) < float(input_pred_phrases[j].split("(")[1].split(")")[0]):
123
+ # if the score of the first mask is lower than the second mask, remove overlapping area from the first mask
124
+ output_masks[i] = np.logical_and(output_masks[i], np.logical_not(input_masks_[j]))
125
+ else:
126
+ # otherwise, remove overlapping area from the second mask
127
+ output_masks[j] = np.logical_and(output_masks[j], np.logical_not(input_masks_[i]))
128
+ return output_masks.transpose(3, 0, 1, 2)[0]
129
+
130
+
131
+ groundingdino_model = None
132
+ sam_predictor = None
133
+ already_converted = {}
134
+ config_file = os.path.join(pretrained_models_folder, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
135
+
136
+
137
+ def _find_higest_points(logits_map, num_top_points=20):
138
+ if num_top_points == 0:
139
+ return logits_map, []
140
+ # find the highest points on the logits map
141
+ gray = cv2.cvtColor(logits_map, cv2.COLOR_BGR2GRAY).astype("uint8")
142
+ # find the highest points
143
+ points = []
144
+ for i in range(num_top_points):
145
+ y, x = np.unravel_index(np.argmax(gray, axis=None), gray.shape)
146
+ points.append((x, y))
147
+ gray[y, x] = 0
148
+ # draw points
149
+ for point in points:
150
+ cv2.drawMarker(logits_map, point, (0, 0, 255), cv2.MARKER_CROSS, 10, 3)
151
+ return logits_map, points
152
+
153
+
154
+ def _find_contour_points(logits_map, num_points=5):
155
+ if num_points == 0:
156
+ return logits_map, []
157
+ # find contours and get number of points on the contour, then draw the points on the image
158
+ gray = cv2.cvtColor(logits_map, cv2.COLOR_BGR2GRAY).astype("uint8")
159
+ ret, thresh = cv2.threshold(gray, 155, 255, 0)
160
+ # erode to make the contour thinner
161
+ kernel = np.ones((13, 13), np.uint8)
162
+ # only apply erode when the image is large enough, otherwise, skip it
163
+ if np.sum(thresh) > (gray.shape[0] * gray.shape[1] * 255 * 0.1):
164
+ erode_iterations = int(np.log2(min(gray.shape[0], gray.shape[1])) - 1)
165
+ thresh = cv2.erode(thresh, kernel, iterations=erode_iterations)
166
+
167
+ # only use the largest contour
168
+ contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
169
+ largest_contour = max(contours, key=cv2.contourArea)
170
+
171
+ points = []
172
+ if len(largest_contour) > num_points:
173
+ for i in range(0, len(largest_contour), len(largest_contour) // num_points):
174
+ if len(points) == num_points:
175
+ break
176
+ x, y = largest_contour[i][0]
177
+ points.append((x, y))
178
+
179
+ # make sure the points are at the same number as num_points
180
+ if len(points) == 0:
181
+ raise ValueError("no points found")
182
+ elif len(points) < num_points:
183
+ for i in range(num_points - len(points)):
184
+ points.append(points[-1])
185
+ elif len(points) > num_points:
186
+ points = points[:num_points]
187
+ else:
188
+ pass
189
+ # draw points
190
+ for point in points:
191
+ # cv2.circle(logits_map, point, 3, (0, 0, 255), -1)
192
+ cv2.drawMarker(logits_map, point, (0, 0, 255), cv2.MARKER_CROSS, 10, 3)
193
+
194
+ return logits_map, points
195
+
196
+
197
+ def _process_logits(logits, pred_phrases, top_n_points):
198
+ # print("logits", logits.shape)
199
+ # torch.Size([3, 1, 468, 500])
200
+ logits = logits.cpu().numpy()[:, 0, :, :]
201
+ logits = ((logits - np.min(logits)) / (np.max(logits) - np.min(logits))) * 255
202
+ logits_maps = []
203
+ points_list = []
204
+ for i, logits_map in enumerate(logits):
205
+ try:
206
+ logits_map = cv2.cvtColor(np.array(logits_map, dtype=np.uint8), cv2.COLOR_GRAY2BGR)
207
+ logits_map, points = _find_higest_points(logits_map, num_top_points=top_n_points)
208
+ if len(points) == 0:
209
+ points = None
210
+ cv2.putText(logits_map, pred_phrases[i], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
211
+ logits_maps.append(logits_map)
212
+ points_list.append(points)
213
+ except Exception as e:
214
+ print("error in _process_logits", e)
215
+ continue
216
+ return logits_maps, points_list
217
+
218
+
219
+ def run_grounded_sam(
220
+ input_image,
221
+ text_prompt,
222
+ box_threshold,
223
+ text_threshold,
224
+ iou_threshold,
225
+ LABELS = [],
226
+ IDS = [],
227
+ llm = None,
228
+ timer = None,
229
+ # for ablation study
230
+ wo_post = False,
231
+ top_n_points = 20,
232
+ ):
233
+ global groundingdino_model, sam_predictor, already_converted
234
+
235
+ # load image
236
+ image_pil = input_image["image"].convert("RGB")
237
+ transformed_image = transform_image(image_pil).to(device=device)
238
+ size = image_pil.size
239
+
240
+ if groundingdino_model is None:
241
+ groundingdino_model = _load_model(config_file, groundingdino_ckpt, device=device)
242
+
243
+ # run grounding dino model
244
+ boxes_filt, scores, pred_phrases = get_grounding_output(groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold)
245
+ timer.check("get_grounding_output")
246
+
247
+ # process boxes
248
+ H, W = size[1], size[0]
249
+ for i in range(boxes_filt.size(0)):
250
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
251
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
252
+ boxes_filt[i][2:] += boxes_filt[i][:2]
253
+ boxes_filt = boxes_filt.cpu()
254
+
255
+ # nms
256
+ nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
257
+ boxes_filt = boxes_filt[nms_idx]
258
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
259
+
260
+ if sam_predictor is None:
261
+ # initialize SAM
262
+ assert sam_ckpt, "sam_ckpt is not found!"
263
+ sam = build_sam_vit_l(checkpoint=sam_ckpt)
264
+ sam.to(device=device).eval()
265
+ sam_predictor = SamPredictor(sam)
266
+ sam_predictor.model.to(device=device)
267
+ image = np.array(image_pil)
268
+ sam_predictor.set_image(image)
269
+
270
+ input_box = torch.tensor(boxes_filt, device=device)
271
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(input_box, image.shape[:2])
272
+ logits, _, _ = sam_predictor.predict_torch(
273
+ point_coords = None,
274
+ point_labels = None,
275
+ boxes = transformed_boxes,
276
+ multimask_output = False,
277
+ return_logits = True,
278
+ hq_token_only = False,
279
+ )
280
+ timer.check("get prob")
281
+
282
+ output_prob_maps, output_points = _process_logits(logits, pred_phrases, top_n_points=top_n_points)
283
+ if top_n_points == 0:
284
+ # processing without points prompt, for ablation study
285
+ print("processing without points prompt, for ablation study")
286
+ point_coords = None
287
+ point_labels = None
288
+ else:
289
+ if None in output_points:
290
+ point_coords = None
291
+ point_labels = None
292
+ else:
293
+ point_coords = torch.tensor(np.array(output_points), device=device)
294
+ point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
295
+ point_labels = torch.ones(point_coords.shape[:2], device=device)
296
+ # print("point_coords", point_coords.shape, point_labels.shape, transformed_boxes.shape)
297
+ transformed_boxes = transformed_boxes[: point_coords.shape[0]]
298
+
299
+ masks, _, _ = sam_predictor.predict_torch(
300
+ point_coords = point_coords,
301
+ point_labels = point_labels,
302
+ boxes = transformed_boxes,
303
+ multimask_output = False,
304
+ hq_token_only = False,
305
+ )
306
+ masks = postprocess_masks(masks, pred_phrases)
307
+ timer.check("postprocess_masks")
308
+
309
+ label_image = Image.new("L", size, color=0)
310
+ label_draw = np.array(label_image)
311
+ output_labels = []
312
+ for mask, pred_phrase in zip(masks, pred_phrases):
313
+ try:
314
+ label = pred_phrase.split("(")[0]
315
+ if label in ["", " "]:
316
+ # skip empty label
317
+ continue
318
+ elif label in LABELS:
319
+ # no need to convert if it's one of the target labels
320
+ post_label = label
321
+ elif label in already_converted:
322
+ # check if the label was converted before to save time and model calls
323
+ post_label = already_converted[label]
324
+ print("already converted: {} to {}".format(label, already_converted[label]))
325
+ else:
326
+ # convert the label using llm model
327
+ label = label.replace(" ", "") if "-" in label else label
328
+ if wo_post:
329
+ print("wo_post is True, for ablation study")
330
+ # skip post refinement, for ablation study
331
+ post_label = label
332
+ else:
333
+ post_label = post_refinement(LABELS, label, llm=llm)
334
+ print("convert from {} to {}".format(label, post_label))
335
+ # cache the conversion result (whether or not the label changed) to avoid repeated LLM calls and save cost
336
+ already_converted.update({label: post_label})
337
+ if post_label not in LABELS:
338
+ raise ValueError("label not found, {} from {}".format(post_label, label))
339
+ output_labels.append(post_label)
340
+ label_index = LABELS.index(post_label)
341
+ label_draw[mask] = IDS[label_index]
342
+ except ValueError as e:
343
+ print("e", e)
344
+ print("label not found: ", pred_phrase)
345
+ traceback.print_exc()
346
+ continue
347
+ timer.check("llm+draw label")
348
+ return label_draw, boxes_filt, output_labels, output_prob_maps, output_points
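
The overlap handling in postprocess_masks can be read off a toy case: when two predicted masks carry different labels and intersect, the intersection is removed from whichever mask has the lower detection score (the number in parentheses of its pred_phrase), so the higher-confidence label keeps the disputed pixels. A small self-contained sketch of that rule, with made-up scores and 4x4 masks:

import numpy as np

mask_a = np.zeros((4, 4), bool); mask_a[:, :3] = True    # e.g. "car(0.62)"
mask_b = np.zeros((4, 4), bool); mask_b[:, 2:] = True    # e.g. "truck(0.48)", overlaps column 2
score_a, score_b = 0.62, 0.48

if score_b < score_a:
    mask_b = np.logical_and(mask_b, np.logical_not(mask_a))   # lower-scoring mask loses the overlap
else:
    mask_a = np.logical_and(mask_a, np.logical_not(mask_b))

print(mask_a[0], mask_b[0])   # -> [ True  True  True False] [False False False  True]
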
utils/labels_utils.py ADDED
@@ -0,0 +1,214 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ COCO_CATEGORIES = [
5
+ # borrowed from detectron2
6
+ # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_stuff_10k.py
7
+ {"color": [0, 0, 0], "isthing": 0, "id": 0, "name": "unlabeled"},
8
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
9
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
10
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
11
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
12
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
13
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
14
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
15
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
16
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
17
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
18
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
19
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
20
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
21
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
22
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
23
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
24
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
25
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
26
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
27
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
28
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
29
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
30
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
31
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
32
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
33
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
34
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
35
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
36
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
37
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
38
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
39
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
40
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
41
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
42
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
43
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
44
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
45
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
46
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
47
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
48
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
49
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
50
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
51
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
52
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
53
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
54
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
55
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
56
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
57
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
58
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
59
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
60
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
61
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
62
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
63
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
64
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
65
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
66
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
67
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
68
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
69
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
70
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
71
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
72
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
73
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
74
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
75
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
76
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
77
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
78
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
79
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
80
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
81
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
82
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
83
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
84
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
85
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
86
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
87
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
88
+ ]
89
+
90
+
91
+ def create_coco_colormap(IDs):
92
+ all_colors = []
93
+ vis_colors = [category["color"] for category in COCO_CATEGORIES]
94
+ used_ids = [category["id"] for category in COCO_CATEGORIES]
95
+ all_colors = [vis_colors[used_ids.index(id)] if id in used_ids else [0, 0, 0] for id in range(max(IDs)+1)]
96
+ return np.array(all_colors, dtype=int)
97
+
98
+
99
+ def create_cityscapes_colormap(IDs):
100
+ vis_colors = [
101
+ (0, 0, 0),
102
+ (128, 64, 128),
103
+ (244, 35, 232),
104
+ (70, 70, 70),
105
+ (102, 102, 156),
106
+ (190, 153, 153),
107
+ (153, 153, 153),
108
+ (250, 170, 30),
109
+ (220, 220, 0),
110
+ (107, 142, 35),
111
+ (152, 251, 152),
112
+ (70, 130, 180),
113
+ (220, 20, 60),
114
+ (255, 0, 0),
115
+ (0, 0, 142),
116
+ (0, 0, 70),
117
+ (0, 60, 100),
118
+ (0, 80, 100),
119
+ (0, 0, 230),
120
+ (119, 11, 32),
121
+ ]
122
+
123
+ all_colors = [vis_colors[IDs.index(id)] if id in IDs else [0, 0, 0] for id in range(max(IDs)+1)]
124
+ return np.array(all_colors, dtype=int)
125
+
126
+ def create_pascal_label_colormap(n_labels=256):
127
+ def bitget(byteval, idx):
128
+ return ((byteval & (1 << idx)) != 0)
129
+
130
+ cmap = np.zeros((n_labels, 3), dtype=np.uint8)
131
+ for i in range(n_labels):
132
+ r = g = b = 0
133
+ c = i
134
+ for j in range(8):
135
+ r = r | (bitget(c, 0) << 7-j)
136
+ g = g | (bitget(c, 1) << 7-j)
137
+ b = b | (bitget(c, 2) << 7-j)
138
+ c = c >> 3
139
+ cmap[i] = np.array([r, g, b])
140
+ return cmap
141
+
142
+
143
+
144
+ class Labels:
145
+ def __init__(self, config=None):
146
+ max_label_num = 200
147
+ if config is not None:
148
+ self.LABELS = config.label_list.split(", ")
149
+ self.IDS = config.mask_ids if hasattr(config, "mask_ids") else [i for i in range(len(self.LABELS))]
150
+ print("self.IDS", self.IDS)
151
+ if len(self.LABELS) > max_label_num:
152
+ raise ValueError(f"Too many labels! The maximum number of labels is {max_label_num}.")
153
+ else:
154
+ raise NotImplementedError("config is None")
155
+
156
+ if "COCO" in config.Name:
157
+ self.COLORS = create_coco_colormap(self.IDS)
158
+ elif "City" in config.Name:
159
+ self.COLORS = create_cityscapes_colormap(self.IDS)
160
+ else:
161
+ # default to pascal label colormap
162
+ self.COLORS = create_pascal_label_colormap()
163
+
164
+ assert len(self.COLORS) >= len(self.LABELS), f"len(self.COLORS)={len(self.COLORS)} < len(self.LABELS)={len(self.LABELS)}"
165
+
166
+ def check_labels(self, labels_list):
167
+ output_labels_list = []
168
+ for labels in labels_list:
169
+ output_labels = []
170
+ labels = labels.split(", ")
171
+ for label in labels:
172
+ if label == "background":
173
+ # skip the background label
174
+ continue
175
+ if label in self.LABELS:
176
+ output_labels.append(label)
177
+ output_labels = list(set(output_labels))
178
+ output_labels_list.append(", ".join(output_labels))
179
+ return output_labels_list
180
+
181
+ def draw_mask(self, label_ori, image_ori, print_label=False, tag="", only_label=False):
182
+ label_ori = label_ori.astype(np.uint8)
183
+ label = np.zeros_like(image_ori, dtype=np.uint8)
184
+ # print("{}: {}".format(tag, np.unique(label_ori)))
185
+ for id in np.unique(label_ori):
186
+ # print("id", id)
187
+ if id == 0 or id == 255:
188
+ continue
189
+ elif id not in self.IDS:
190
+ print(f"Label {id} is not in the label list.")
191
+ continue
192
+ i = self.IDS.index(id)
193
+ center = np.mean(np.argwhere(label_ori == id), axis=0).astype(np.int64)
194
+ label[label_ori == id] = self.COLORS[id]
195
+ if print_label:
196
+ # add text in the center of the mask
197
+ cv2.putText(label, self.LABELS[i], (center[1], center[0]), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
198
+ # print(i, self.LABELS[i])
199
+ # RGB to BGR
200
+ label = cv2.cvtColor(label, cv2.COLOR_RGB2BGR)
201
+ return cv2.addWeighted(label, 0.6, image_ori, 0.4, 0) if not only_label else label
202
+
203
+ def find_gt_labels(self, label_gt):
204
+ label_gt = label_gt.astype(np.uint8)
205
+ label_gt_list = []
206
+ for id in np.unique(label_gt):
207
+ if id == 0 or id == 255:
208
+ continue
209
+ elif id not in self.IDS:
210
+ print(f"Label {id} is not in the label list.")
211
+ continue
212
+ i = self.IDS.index(id)
213
+ label_gt_list.append(self.LABELS[i])
214
+ return label_gt_list
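
For datasets other than COCO and Cityscapes, Labels falls back to create_pascal_label_colormap, which rebuilds the standard PASCAL VOC palette by spreading bits 0/1/2 of the label index across the high bits of R/G/B and then shifting the index right by three. A quick self-contained check of the first few entries, with the same bit logic rewritten inline:

def pascal_color(i):
    r = g = b = 0
    c = i
    for j in range(8):
        r |= ((c >> 0) & 1) << (7 - j)
        g |= ((c >> 1) & 1) << (7 - j)
        b |= ((c >> 2) & 1) << (7 - j)
        c >>= 3
    return (r, g, b)

print([pascal_color(i) for i in range(4)])
# -> [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0)]
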
utils/llms_utils.py ADDED
@@ -0,0 +1,233 @@
1
+
2
+ import os
3
+ import torch
4
+ from openai import OpenAI
5
+ from termcolor import colored
6
+
7
+ import transformers
8
+ # from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
9
+ from huggingface_hub import login
10
+
11
+ # environment variables and paths
12
+ from .env_utils import get_device, low_vram_mode
13
+
14
+ device = get_device()
15
+
16
+ class GPT:
17
+ def __init__(self, model="gpt-4o-mini", api_key=None):
18
+ self.prices = {
19
+ # check at https://openai.com/api/pricing/
20
+ "gpt-3.5-turbo-0125": [0.0000005, 0.0000015],
21
+ "gpt-4o-mini" : [0.00000015, 0.00000060],
22
+ "gpt-4-1106-preview": [0.00001, 0.00003],
23
+ "gpt-4-0125-preview": [0.00001, 0.00003],
24
+ "gpt-4-turbo" : [0.00001, 0.00003],
25
+ "gpt-4o" : [0.000005, 0.000015],
26
+ }
27
+ self.cheaper_model = "gpt-4o-mini"
28
+ assert model in self.prices.keys(), "Invalid model, please choose from: {}, or add new models in the code.".format(self.prices.keys())
29
+ self.model = model
30
+ print(f"Using {model}")
31
+ self.client = OpenAI(api_key=api_key)
32
+ self.total_cost = 0.0
33
+
34
+ def _update(self, response, price):
35
+ current_cost = response.usage.prompt_tokens * price[0] + response.usage.completion_tokens * price[1]  # price = [input, output] per token
36
+ self.total_cost += current_cost
37
+ # print in 4 decimal places
38
+ print(
39
+ colored(
40
+ f"Current Tokens: {response.usage.completion_tokens + response.usage.prompt_tokens:d} \
41
+ Current cost: {current_cost:.4f} $, \
42
+ Total cost: {self.total_cost:.4f} $",
43
+ "yellow",
44
+ )
45
+ )
46
+
47
+ def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
48
+ # set temperature to 0.0 for more deterministic results
49
+ if post:
50
+ # use cheaper model for post-refinement to save costs, since the task is simpler.
51
+ generated_text = self.client.chat.completions.create(
52
+ model=self.cheaper_model, messages=messages, temperature=temperature, max_tokens=max_tokens
53
+ )
54
+ self._update(generated_text, self.prices[self.cheaper_model])
55
+ else:
56
+ generated_text = self.client.chat.completions.create(
57
+ model=self.model, messages=messages, temperature=temperature, max_tokens=max_tokens
58
+ )
59
+ self._update(generated_text, self.prices[self.model])
60
+ generated_text = generated_text.choices[0].message.content
61
+ return generated_text
62
+
63
+
64
+ class Llama3:
65
+ def __init__(self, model="Meta-Llama-3-8B-Instruct"):
66
+ login(token=os.getenv('HF_TOKEN'))
67
+ model = "meta-llama/{}".format(model) # or replace with your local model path
68
+ print(f"Using {model}")
69
+ # ZeroGPU does not support quantization.
70
+ # tokenizer = AutoTokenizer.from_pretrained(model)
71
+ # if low_vram_mode:
72
+ # model = AutoModelForCausalLM.from_pretrained(
73
+ # model, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
74
+ # ).eval()
75
+ self.pipeline = transformers.pipeline(
76
+ "text-generation",
77
+ model = model,
78
+ # tokenizer = tokenizer,
79
+ model_kwargs = {"torch_dtype": torch.bfloat16},
80
+ device_map = "auto",
81
+ )
82
+ self.terminators = [self.pipeline.tokenizer.eos_token_id, self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
83
+
84
+ def _update(self):
85
+ print(colored("Using Llama-3, Free", "green"))
86
+
87
+ def chat(self, messages, temperature=0.0, max_tokens=200, post=False):
88
+ prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
89
+ generated_text = self.pipeline(
90
+ prompt,
91
+ max_new_tokens = max_tokens,
92
+ eos_token_id = self.terminators,
93
+ pad_token_id = 128001,
94
+ do_sample = True,
95
+ temperature = max(temperature, 0.01), # 0.0 is not supported
96
+ top_p = 0.9,
97
+ )
98
+ self._update()
99
+ generated_text = generated_text[0]["generated_text"][len(prompt) :]
100
+ return generated_text
101
+
102
+
103
+ # Define the timeout handler
104
+ def timeout_handler(signum, frame):
105
+ raise TimeoutError()
106
+
107
+
108
+ def init_model(model, api_key=None):
109
+ if "gpt" in model:
110
+ return GPT(model=model, api_key=api_key)
111
+ elif "Llama" in model:
112
+ return Llama3(model=model)
113
+ else:
114
+ raise ValueError("Invalid model")
115
+
116
+
117
+ def _generate_example_prompt(examples, llm=None):
118
+ # system prompt
119
+ system_prompt = """
120
+ Task Description:
121
+ - you will provide detailed explanations for example inputs and outputs within the context of the task.
122
+
123
+ Please adhere to the following rules:
124
+ - Exclude terms that appear in both lists.
125
+ - Detail the relevance of unmatched terms from input to output, focusing on indirect relationships.
126
+ - Identify and explain terms common to all output lists but rarely present in input lists; include these at the end of the output labeled 'Recommend Include Labels'.
127
+ - Each explanation should be concise, around 50 words.
128
+
129
+ Output Format:
130
+ - '1. Input... Output... Explanation... n. Input... Output... Explanation... \n Recommend Include Labels: label1, labeln, ...'
131
+ """
132
+ messages = [
133
+ {"role": "system", "content": system_prompt},
134
+ {
135
+ "role": "user",
136
+ "content": f"Here are the input and output lists for which you need to provide detailed explanations:{examples.strip()}",
137
+ },
138
+ ]
139
+ generated_example = llm.chat(messages, temperature=0.0, max_tokens=1000)
140
+ return generated_example
141
+
142
+
143
+ def _make_prompt(label_list, example=None):
+     Cityscape = "sidewalk" in label_list
+     if Cityscape:
+         add_text = f'contain at least {len(label_list.split(", "))} labels, '
+     else:
+         add_text = ""
+     # Task description and instructions for processing the input to generate output
+     system_prompt = f"""
+ Task Description:
+ - You will receive a list of caption tags accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".
+
+ Instructions:
+ Step 1. Visualize the scene suggested by the input caption tags and text.
+ Step 2. Analyze each term within the overall scene to predict relevant labels from the predefined list, ensuring no term is overlooked.
+ Step 3. Now forget the input list and focus on the scene as a whole, expanding upon the labels to include any contextually relevant labels that complete the scene or setting.
+ Step 4. Compile all identified labels into a comma-separated list, adhering strictly to the specified format.
+
+ Contextually Relevant Tips:
+ - Equivalencies include converting "girl, man" to "person" and "flower, vase" to "potted plant", while "bicycle, motorcycle" suggest "rider".
+ - An outdoor scene may include labels like "sky", "tree", "clouds", "terrain".
+ - An urban scene may imply "bus", "bicycle", "road", "sidewalk", "building", "pole", "traffic-light", "traffic-sign".
+
+ Output:
+ - Do not output any explanations other than the final label list.
+ - The final output should {add_text}strictly adhere to the specified format: label1, label2, ... labeln
+     """.strip()
+     if example:
+         system_prompt += f"""
+ Additional Examples with Detailed Explanations:
+ {example}
+     """
+     print("system_prompt: ", system_prompt)
+     return system_prompt
+
+
+ # - You will receive a list of terms accompanied by a caption text and must assign appropriate labels from a predefined label list: "{label_list}".
+
+ # Instructions:
+ # Step 1. Visualize the scene suggested by the input list and caption text.
+
+
+ def make_prompt(label_list):
+     # Create a new system prompt using the label list and the improved example prompt
+     system_prompt = _make_prompt(label_list)
+     system_prompt = {"role": "system", "content": system_prompt.strip()}
+     print("system_prompt: ", system_prompt)
+     return system_prompt
+
+
+ def _call_llm(system_prompt, llm, user_input):
+     messages = [system_prompt, {"role": "user", "content": "Here are input caption tags and text: " + user_input}]
+     converted_label = llm.chat(messages=messages, temperature=0.0, max_tokens=200)
+     return converted_label
+
+
+ def pre_refinement(user_input_list, system_prompt, llm=None):
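+     # Queries the LLM once per caption-tag string, then prepends the original tags to each
+     # answer so later stages see both the raw tags and the predicted labels.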
+     llm_outputs = [_call_llm(system_prompt, llm, user_input) for user_input in user_input_list]
+     converted_labels = [f"{user_input_}, {converted_label}" for user_input_, converted_label in zip(user_input_list, llm_outputs)]
+     return converted_labels, llm_outputs
+
+
+ def post_refinement(label_list, detected_label, llm=None):
+     system_input = f"""
+ Task Description:
+ - You will receive a specific phrase and must assign an appropriate label from the predefined label list: "{label_list}". \n \
+
+ Please adhere to the following rules: \n \
+ - Select and return only one relevant label from the predefined label list that corresponds to the given phrase. \n \
+ - Do not include any additional information or context beyond the label itself. \n \
+ - Format is purely the label itself, without any additional punctuation or formatting. \n \
+     """
+     system_input = {"role": "system", "content": system_input}
+     messages = [system_input, {"role": "user", "content": detected_label}]
+     if detected_label == "":
+         return ""
+     generated_label = None
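+     # Retry up to three times, raising the temperature slightly on each retry to
+     # encourage a non-empty answer from the LLM.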
+     for count in range(3):
+         generated_label = llm.chat(messages=messages, temperature=0.0 if count == 0 else 0.1 * count, post=True)
+         if generated_label != "":
+             break
+     return generated_label
+
+
+ if __name__ == "__main__":
+     # test the functions
+     llm = Llama3(model="Meta-Llama-3-8B-Instruct")
+
+     system_prompt = make_prompt("person, car, tree, sky, road, building, sidewalk, traffic-light, traffic-sign")
+
+     converted_labels, llm_outputs = pre_refinement(["person, car, road, traffic-light"], system_prompt, llm=llm)
+     print("converted_labels: ", converted_labels)
+     print("llm_outputs: ", llm_outputs)
utils/ram_utils.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import torch
+ from PIL import Image
+ from .env_utils import get_device, low_vram_mode
+
+ device = get_device()
+
+ pretrained_models_folder = os.path.join(os.path.dirname(__file__), "../pretrained-models")
+
+
+ # RAM++
+ from ram.models import ram_plus
+ from ram import get_transform, inference_ram
+
+ ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
+ ram_precision = torch.bfloat16
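+ # Note: bfloat16 inference roughly halves memory versus float32; devices without
+ # bfloat16 support may need torch.float32 here instead.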
+
+
+ def ram_init():
+     image_size = 384
+     transform = get_transform(image_size=image_size)
+     #######load model#######
+     model = ram_plus(pretrained=ram_ckpt, image_size=image_size, vit="swin_l")
+     model = model.to(device=device, dtype=ram_precision)
+     model.eval()
+     print("RAM++ model loaded")
+     return model, transform
+
+
+ # Initialize the model when importing the module
+ ram_model, ram_transform = ram_init()
+
+
+ def _inference(image_pil):
+     image = ram_transform(image_pil).unsqueeze(0)
+     image = image.to(device=device, dtype=ram_precision)
+     res = inference_ram(image, ram_model)
+     result = res[0].replace(" | ", ", ")
+     return result
+
+
+ def _split_large_image(image_pil):
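+     # Crops the image into a 2x2 grid of equally sized patches; for odd dimensions the
+     # rightmost column / bottom row of pixels is dropped by the integer division below.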
+     size = image_pil.size
+     print("Image size is too large, split into smaller patches")
+     # Split the image into 4 patches
+     patches = []
+     patch_size = (size[0] // 2, size[1] // 2)
+     for i in range(2):
+         for j in range(2):
+             left = i * patch_size[0]
+             top = j * patch_size[1]
+             right = left + patch_size[0]
+             bottom = top + patch_size[1]
+             patch = image_pil.crop((left, top, right, bottom))
+             patches.append(patch)
+     return patches
+
+
+ def ram_inference(image_pil: Image.Image):
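+     # Images larger than 640 px on either side are tagged patch by patch, and the per-patch
+     # tags are deduplicated with set(), so the order of the returned tags is not deterministic.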
+     size = image_pil.size
+     if size[0] > 640 or size[1] > 640:
+         # split only once in the online demo version.
+         patches = _split_large_image(image_pil)
+         # while any(patch.size[0] > 640 or patch.size[1] > 640 for patch in patches):
+         #     patches = [_split_large_image(patch) for patch in patches]
+         #     patches = [patch for sublist in patches for patch in sublist]
+         # Inference on each patch
+         results = []
+         for patch in patches:
+             result = _inference(patch)
+             results.extend(result.split(", "))
+         results = list(set(results))
+         # Combine the results
+         final_result = ", ".join(results)
+         return final_result
+     else:
+         print("Image size is small enough for inference")
+         return _inference(image_pil)
+
+
+ if __name__ == "__main__":
+     # Test the RAM++ model
+     image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
+     image = Image.open(image_path)
+     result = ram_inference(image)
+     print(result)
utils/timer_utils.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import time
+ import logging
+
+
+ def create_logger(logger_name: str, log_file_path: os.PathLike = None):
+     """
+     Create a logger with the specified name and log file path.
+     """
+     logger = logging.getLogger(logger_name)
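+     # propagate=False keeps these entries out of the root logger, so timing logs go only
+     # to the dedicated file handler added below.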
+     logger.propagate = False
+     logger.setLevel(logging.DEBUG)
+     assert log_file_path is not None, "log_file_path is required"
+     fh = logging.FileHandler(log_file_path)
+     fh_formatter = logging.Formatter("%(asctime)s : %(levelname)s, %(funcName)s Message: %(message)s")
+     fh.setFormatter(fh_formatter)
+     logger.addHandler(fh)
+     logger.info(f"logging start: {logger_name}")
+     return logger
+
+
+ class Timer:
+     """
+     A simple timer class for measuring elapsed time.
+     """
+
+     def __init__(self, filename: os.PathLike = "timer_log.log", reset: bool = False):
+         """
+         Initialize the Timer object.
+         """
+         self.start_time = None
+         self.last_checkpoint = None
+         self.filename = filename
+         self.logger = create_logger("Timer", filename)
+         if reset:
+             self._reset_log_file()
+
+     def _reset_log_file(self):
+         """
+         Reset the log file by clearing its contents.
+         """
+         with open(self.filename, "w") as file:
+             file.write("")
+
+     def start(self):
+         """
+         Start the timer.
+         """
+         self.start_time = time.time()
+         self.last_checkpoint = self.start_time
+         self.logger.info("Timer started.")
+
+     def check(self, message):
+         """
+         Log a checkpoint with the current time and time since the last checkpoint.
+
+         Args:
+             message (str): The message to include in the log.
+         """
+         if self.start_time is None:
+             self.logger.warning("Timer has not been started.")
+         else:
+             log_message = (
+                 f"Current time count: {time.time() - self.start_time:.4f} seconds, "
+                 f"Time since last checkpoint: {time.time() - self.last_checkpoint:.4f} seconds, "
+                 f"for {message}"
+             )
+             self.last_checkpoint = time.time()
+             self.logger.info(log_message)
+
+     def stop(self):
+         """
+         Stop the timer and log the elapsed time.
+         """
+         if self.start_time is None:
+             self.logger.warning("Timer has not been started.")
+         else:
+             self.end_time = time.time()
+             self.logger.info(f"Total elapsed time: {self.end_time - self.start_time} seconds\n")
+
+
+ if __name__ == "__main__":
+     # Test the Timer class
+     timer = Timer(filename="timer_log.log", reset=True)
+     timer.start()
+     timer.check("First checkpoint")
+     time.sleep(1)
+     timer.check("Second checkpoint")
+     timer.stop()