Spaces:
Running
on
Zero
Running
on
Zero
bol
commited on
Commit
·
99738e0
1
Parent(s):
e493783
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .gitattributes +1 -0
- .gitignore +29 -0
- README.md +8 -6
- app.py +163 -144
- app_old.py +374 -0
- image_datasets/.DS_Store +0 -0
- image_datasets/dataset.py +231 -0
- inference_configs/inference.yaml +33 -0
- requirements.txt +25 -5
- src/.DS_Store +0 -0
- src/flux/.DS_Store +0 -0
- src/flux/__init__.py +11 -0
- src/flux/__main__.py +4 -0
- src/flux/annotator/canny/__init__.py +6 -0
- src/flux/annotator/ckpts/ckpts.txt +1 -0
- src/flux/annotator/dwpose/__init__.py +68 -0
- src/flux/annotator/dwpose/onnxdet.py +125 -0
- src/flux/annotator/dwpose/onnxpose.py +360 -0
- src/flux/annotator/dwpose/util.py +297 -0
- src/flux/annotator/dwpose/wholebody.py +48 -0
- src/flux/annotator/hed/__init__.py +95 -0
- src/flux/annotator/midas/LICENSE +21 -0
- src/flux/annotator/midas/__init__.py +42 -0
- src/flux/annotator/midas/api.py +168 -0
- src/flux/annotator/midas/midas/__init__.py +0 -0
- src/flux/annotator/midas/midas/base_model.py +16 -0
- src/flux/annotator/midas/midas/blocks.py +342 -0
- src/flux/annotator/midas/midas/dpt_depth.py +109 -0
- src/flux/annotator/midas/midas/midas_net.py +76 -0
- src/flux/annotator/midas/midas/midas_net_custom.py +128 -0
- src/flux/annotator/midas/midas/transforms.py +234 -0
- src/flux/annotator/midas/midas/vit.py +491 -0
- src/flux/annotator/midas/utils.py +189 -0
- src/flux/annotator/mlsd/LICENSE +201 -0
- src/flux/annotator/mlsd/__init__.py +40 -0
- src/flux/annotator/mlsd/models/mbv2_mlsd_large.py +292 -0
- src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py +275 -0
- src/flux/annotator/mlsd/utils.py +580 -0
- src/flux/annotator/tile/__init__.py +26 -0
- src/flux/annotator/tile/guided_filter.py +280 -0
- src/flux/annotator/util.py +38 -0
- src/flux/annotator/zoe/LICENSE +21 -0
- src/flux/annotator/zoe/__init__.py +48 -0
- src/flux/annotator/zoe/zoedepth/data/__init__.py +24 -0
- src/flux/annotator/zoe/zoedepth/data/data_mono.py +573 -0
- src/flux/annotator/zoe/zoedepth/data/ddad.py +117 -0
- src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py +125 -0
- src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py +114 -0
- src/flux/annotator/zoe/zoedepth/data/diode.py +125 -0
.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/**/* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
assets/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
pip-wheel-metadata/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
README.md
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
10 |
---
|
11 |
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: ByteMorph Demo
|
3 |
+
emoji: 📊
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.31.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: other
|
11 |
+
short_description: Online Demo for ByteMorph
|
12 |
---
|
13 |
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,154 +1,173 @@
|
|
1 |
import gradio as gr
|
2 |
-
import numpy as np
|
3 |
-
import random
|
4 |
-
|
5 |
-
# import spaces #[uncomment to use ZeroGPU]
|
6 |
-
from diffusers import DiffusionPipeline
|
7 |
import torch
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
#
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
#
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
"""
|
66 |
|
67 |
-
|
68 |
-
with gr.
|
69 |
-
gr.Markdown(
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
show_label=False,
|
75 |
-
max_lines=1,
|
76 |
-
placeholder="Enter your prompt",
|
77 |
-
container=False,
|
78 |
-
)
|
79 |
-
|
80 |
-
run_button = gr.Button("Run", scale=0, variant="primary")
|
81 |
-
|
82 |
-
result = gr.Image(label="Result", show_label=False)
|
83 |
-
|
84 |
-
with gr.Accordion("Advanced Settings", open=False):
|
85 |
-
negative_prompt = gr.Text(
|
86 |
-
label="Negative prompt",
|
87 |
-
max_lines=1,
|
88 |
-
placeholder="Enter a negative prompt",
|
89 |
-
visible=False,
|
90 |
-
)
|
91 |
-
|
92 |
-
seed = gr.Slider(
|
93 |
-
label="Seed",
|
94 |
-
minimum=0,
|
95 |
-
maximum=MAX_SEED,
|
96 |
-
step=1,
|
97 |
-
value=0,
|
98 |
-
)
|
99 |
-
|
100 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
101 |
-
|
102 |
-
with gr.Row():
|
103 |
-
width = gr.Slider(
|
104 |
-
label="Width",
|
105 |
-
minimum=256,
|
106 |
-
maximum=MAX_IMAGE_SIZE,
|
107 |
-
step=32,
|
108 |
-
value=1024, # Replace with defaults that work for your model
|
109 |
-
)
|
110 |
-
|
111 |
-
height = gr.Slider(
|
112 |
-
label="Height",
|
113 |
-
minimum=256,
|
114 |
-
maximum=MAX_IMAGE_SIZE,
|
115 |
-
step=32,
|
116 |
-
value=1024, # Replace with defaults that work for your model
|
117 |
)
|
|
|
|
|
118 |
|
119 |
-
with gr.
|
120 |
-
|
121 |
-
label="Guidance scale",
|
122 |
-
minimum=0.0,
|
123 |
-
maximum=10.0,
|
124 |
-
step=0.1,
|
125 |
-
value=0.0, # Replace with defaults that work for your model
|
126 |
-
)
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
)
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
],
|
150 |
-
outputs=[result, seed],
|
151 |
-
)
|
152 |
|
153 |
if __name__ == "__main__":
|
154 |
-
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
+
import spaces
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
from safetensors.torch import load_file
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
|
12 |
+
from src.flux.xflux_pipeline import XFluxSampler
|
13 |
+
from image_datasets.dataset import image_resize
|
14 |
+
|
15 |
+
# ===== No CUDA/model initialization globally =====
|
16 |
+
args = OmegaConf.load("inference_configs/inference.yaml")
|
17 |
+
is_schnell = args.model_name == "flux-schnell"
|
18 |
+
|
19 |
+
# sampler = None
|
20 |
+
device = torch.device("cuda")
|
21 |
+
dtype = torch.bfloat16
|
22 |
+
dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
|
23 |
+
vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
|
24 |
+
t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
|
25 |
+
clip = load_clip("cpu").to(device, dtype=dtype)
|
26 |
+
|
27 |
+
vae.requires_grad_(False)
|
28 |
+
t5.requires_grad_(False)
|
29 |
+
clip.requires_grad_(False)
|
30 |
+
|
31 |
+
model_path = hf_hub_download(
|
32 |
+
repo_id="Boese0601/ByteMorpher",
|
33 |
+
filename="dit.safetensors",
|
34 |
+
use_auth_token=os.getenv("HF_TOKEN")
|
35 |
+
)
|
36 |
+
state_dict = load_file(model_path)
|
37 |
+
dit.load_state_dict(state_dict)
|
38 |
+
dit.eval()
|
39 |
+
dit.to(device, dtype=dtype)
|
40 |
+
|
41 |
+
sampler = XFluxSampler(
|
42 |
+
clip=clip,
|
43 |
+
t5=t5,
|
44 |
+
ae=vae,
|
45 |
+
model=dit,
|
46 |
+
device=device,
|
47 |
+
ip_loaded=False,
|
48 |
+
spatial_condition=False,
|
49 |
+
clip_image_processor=None,
|
50 |
+
image_encoder=None,
|
51 |
+
improj=None
|
52 |
+
)
|
53 |
+
#test push
|
54 |
+
@spaces.GPU
|
55 |
+
def generate(image: Image.Image, edit_prompt: str):
|
56 |
+
# global sampler
|
57 |
+
# device = torch.device("cuda")
|
58 |
+
# dtype = torch.bfloat16
|
59 |
+
|
60 |
+
# if sampler is None:
|
61 |
+
# dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
|
62 |
+
# vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
|
63 |
+
# t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
|
64 |
+
# clip = load_clip("cpu").to(device, dtype=dtype)
|
65 |
+
|
66 |
+
# vae.requires_grad_(False)
|
67 |
+
# t5.requires_grad_(False)
|
68 |
+
# clip.requires_grad_(False)
|
69 |
+
|
70 |
+
# model_path = hf_hub_download(
|
71 |
+
# repo_id="Boese0601/ByteMorpher",
|
72 |
+
# filename="dit.safetensors",
|
73 |
+
# use_auth_token=os.getenv("HF_TOKEN")
|
74 |
+
# )
|
75 |
+
# state_dict = load_file(model_path)
|
76 |
+
# dit.load_state_dict(state_dict)
|
77 |
+
# dit.eval()
|
78 |
+
|
79 |
+
# sampler = XFluxSampler(
|
80 |
+
# clip=clip,
|
81 |
+
# t5=t5,
|
82 |
+
# ae=vae,
|
83 |
+
# model=dit,
|
84 |
+
# device=device,
|
85 |
+
# ip_loaded=False,
|
86 |
+
# spatial_condition=False,
|
87 |
+
# clip_image_processor=None,
|
88 |
+
# image_encoder=None,
|
89 |
+
# improj=None
|
90 |
+
# )
|
91 |
+
|
92 |
+
img = image_resize(image, 512)
|
93 |
+
w, h = img.size
|
94 |
+
img = img.resize(((w // 32) * 32, (h // 32) * 32))
|
95 |
+
img = torch.from_numpy((np.array(img) / 127.5) - 1)
|
96 |
+
img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
|
97 |
+
|
98 |
+
with torch.no_grad():
|
99 |
+
result = sampler(
|
100 |
+
prompt=edit_prompt,
|
101 |
+
width=args.sample_width,
|
102 |
+
height=args.sample_height,
|
103 |
+
num_steps=args.sample_steps,
|
104 |
+
image_prompt=None,
|
105 |
+
true_gs=args.cfg_scale,
|
106 |
+
seed=args.seed,
|
107 |
+
ip_scale=args.ip_scale if args.use_ip else 1.0,
|
108 |
+
source_image=img if args.use_spatial_condition else None,
|
109 |
+
)
|
110 |
+
return tensor_to_pil_image(result)
|
111 |
+
|
112 |
+
def get_samples():
|
113 |
+
sample_list = [
|
114 |
+
{
|
115 |
+
"image": "assets/0_camera_zoom/20486354.png",
|
116 |
+
"edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
|
117 |
+
},
|
118 |
+
]
|
119 |
+
return [
|
120 |
+
[
|
121 |
+
Image.open(sample["image"]).resize((512, 512)),
|
122 |
+
sample["edit_prompt"],
|
123 |
+
]
|
124 |
+
for sample in sample_list
|
125 |
+
]
|
126 |
+
|
127 |
+
header = """
|
128 |
+
# ByteMorph
|
129 |
+
|
130 |
+
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
|
131 |
+
<a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
|
132 |
+
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
|
133 |
+
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
|
134 |
+
</div>
|
135 |
"""
|
136 |
|
137 |
+
def create_app():
|
138 |
+
with gr.Blocks() as app:
|
139 |
+
gr.Markdown(header, elem_id="header")
|
140 |
+
with gr.Row(equal_height=False):
|
141 |
+
with gr.Column(variant="panel", elem_classes="inputPanel"):
|
142 |
+
original_image = gr.Image(
|
143 |
+
type="pil", label="Condition Image", width=300, elem_id="input"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
)
|
145 |
+
edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
|
146 |
+
submit_btn = gr.Button("Run", elem_id="submit_btn")
|
147 |
|
148 |
+
with gr.Column(variant="panel", elem_classes="outputPanel"):
|
149 |
+
output_image = gr.Image(type="pil", elem_id="output")
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
+
with gr.Row():
|
152 |
+
examples = gr.Examples(
|
153 |
+
examples=get_samples(),
|
154 |
+
inputs=[original_image, edit_prompt],
|
155 |
+
label="Examples",
|
156 |
+
)
|
|
|
157 |
|
158 |
+
submit_btn.click(
|
159 |
+
fn=generate,
|
160 |
+
inputs=[original_image, edit_prompt],
|
161 |
+
outputs=output_image,
|
162 |
+
)
|
163 |
+
gr.HTML(
|
164 |
+
"""
|
165 |
+
<div style="text-align: center;">
|
166 |
+
* This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
|
167 |
+
</div>
|
168 |
+
"""
|
169 |
+
)
|
170 |
+
return app
|
|
|
|
|
|
|
171 |
|
172 |
if __name__ == "__main__":
|
173 |
+
create_app().launch(debug=False, share=False, ssr_mode=False)
|
app_old.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import spaces
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
# from src.condition import Condition
|
6 |
+
from diffusers.pipelines import FluxPipeline
|
7 |
+
import numpy as np
|
8 |
+
import requests
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from safetensors.torch import load_file
|
11 |
+
import torch.multiprocessing as mp
|
12 |
+
###
|
13 |
+
import argparse
|
14 |
+
import logging
|
15 |
+
import math
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
import random
|
19 |
+
import shutil
|
20 |
+
from contextlib import nullcontext
|
21 |
+
from pathlib import Path
|
22 |
+
from PIL import Image
|
23 |
+
import accelerate
|
24 |
+
import datasets
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from torch import Tensor, nn
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
import transformers
|
31 |
+
from accelerate import Accelerator
|
32 |
+
from accelerate.logging import get_logger
|
33 |
+
from accelerate.state import AcceleratorState
|
34 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
35 |
+
from huggingface_hub import create_repo, upload_folder
|
36 |
+
from packaging import version
|
37 |
+
from tqdm.auto import tqdm
|
38 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
39 |
+
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
|
40 |
+
from transformers.utils import ContextManagers
|
41 |
+
from omegaconf import OmegaConf
|
42 |
+
from copy import deepcopy
|
43 |
+
import diffusers
|
44 |
+
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
|
45 |
+
from diffusers.optimization import get_scheduler
|
46 |
+
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
|
47 |
+
from diffusers.utils import check_min_version, deprecate, make_image_grid
|
48 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
49 |
+
from diffusers.utils.import_utils import is_xformers_available
|
50 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
51 |
+
from einops import rearrange
|
52 |
+
from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
53 |
+
from src.flux.util import (configs, load_ae, load_clip,
|
54 |
+
load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint)
|
55 |
+
from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor, IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel
|
56 |
+
from src.flux.xflux_pipeline import XFluxSampler
|
57 |
+
|
58 |
+
from image_datasets.dataset import loader, eval_image_pair_loader, image_resize
|
59 |
+
|
60 |
+
from safetensors.torch import load_file
|
61 |
+
import json
|
62 |
+
|
63 |
+
|
64 |
+
# logger = get_logger(__name__, log_level="INFO")
|
65 |
+
|
66 |
+
|
67 |
+
def get_models(name: str, device, offload: bool, is_schnell: bool):
|
68 |
+
t5 = load_t5(device, max_length=256 if is_schnell else 512)
|
69 |
+
clip = load_clip(device)
|
70 |
+
clip.requires_grad_(False)
|
71 |
+
model = load_flow_model2(name, device="cpu")
|
72 |
+
vae = load_ae(name, device="cpu" if offload else device)
|
73 |
+
return model, vae, t5, clip
|
74 |
+
|
75 |
+
args = OmegaConf.load("inference_configs/inference.yaml") #OmegaConf.load(parse_args())
|
76 |
+
is_schnell = args.model_name == "flux-schnell"
|
77 |
+
set_seed(args.seed)
|
78 |
+
# logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
79 |
+
device = "cuda"
|
80 |
+
dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
|
81 |
+
|
82 |
+
# # load image encoder
|
83 |
+
# ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
|
84 |
+
# # accelerator.device, dtype=torch.bfloat16
|
85 |
+
# device, dtype=torch.bfloat16
|
86 |
+
# )
|
87 |
+
# ip_clip_image_processor = CLIPImageProcessor()
|
88 |
+
|
89 |
+
if args.use_ip:
|
90 |
+
sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
|
91 |
+
elif args.use_spatial_condition:
|
92 |
+
sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding)
|
93 |
+
else:
|
94 |
+
sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
|
95 |
+
|
96 |
+
|
97 |
+
# @spaces.GPU
|
98 |
+
def generate(image, edit_prompt):
|
99 |
+
print("hello?????????!!!!!")
|
100 |
+
|
101 |
+
# accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
102 |
+
|
103 |
+
# accelerator = Accelerator(
|
104 |
+
# gradient_accumulation_steps=1,
|
105 |
+
# mixed_precision=args.mixed_precision,
|
106 |
+
# log_with=args.report_to,
|
107 |
+
# project_config=accelerator_project_config,
|
108 |
+
# )
|
109 |
+
|
110 |
+
# Make one log on every process with the configuration for debugging.
|
111 |
+
# logging.basicConfig(
|
112 |
+
# format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
113 |
+
# datefmt="%m/%d/%Y %H:%M:%S",
|
114 |
+
# level=logging.INFO,
|
115 |
+
# )
|
116 |
+
# logger.info(accelerator.state, main_process_only=False)
|
117 |
+
# if accelerator.is_local_main_process:
|
118 |
+
# datasets.utils.logging.set_verbosity_warning()
|
119 |
+
# transformers.utils.logging.set_verbosity_warning()
|
120 |
+
# diffusers.utils.logging.set_verbosity_info()
|
121 |
+
# else:
|
122 |
+
# datasets.utils.logging.set_verbosity_error()
|
123 |
+
# transformers.utils.logging.set_verbosity_error()
|
124 |
+
# diffusers.utils.logging.set_verbosity_error()
|
125 |
+
|
126 |
+
|
127 |
+
# if accelerator.is_main_process:
|
128 |
+
# if args.output_dir is not None:
|
129 |
+
# os.makedirs(args.output_dir, exist_ok=True)
|
130 |
+
# gpt_eval_path = os.path.join(args.output_dir,"Eval")
|
131 |
+
# os.makedirs(gpt_eval_path, exist_ok=True)
|
132 |
+
|
133 |
+
# dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell)
|
134 |
+
# dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
|
135 |
+
|
136 |
+
if args.use_lora:
|
137 |
+
lora_attn_procs = {}
|
138 |
+
if args.use_ip:
|
139 |
+
ip_attn_procs = {}
|
140 |
+
if args.double_blocks is None:
|
141 |
+
double_blocks_idx = list(range(19))
|
142 |
+
else:
|
143 |
+
double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")]
|
144 |
+
|
145 |
+
if args.single_blocks is None:
|
146 |
+
single_blocks_idx = list(range(38))
|
147 |
+
elif args.single_blocks is not None:
|
148 |
+
single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")]
|
149 |
+
|
150 |
+
if args.use_lora:
|
151 |
+
for name, attn_processor in dit.attn_processors.items():
|
152 |
+
match = re.search(r'\.(\d+)\.', name)
|
153 |
+
if match:
|
154 |
+
layer_index = int(match.group(1))
|
155 |
+
|
156 |
+
if name.startswith("double_blocks") and layer_index in double_blocks_idx:
|
157 |
+
# if accelerator.is_main_process:
|
158 |
+
# print("setting LoRA Processor for", name)
|
159 |
+
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
|
160 |
+
dim=3072, rank=args.rank
|
161 |
+
)
|
162 |
+
elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
|
163 |
+
# if accelerator.is_main_process:
|
164 |
+
# print("setting LoRA Processor for", name)
|
165 |
+
lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
|
166 |
+
dim=3072, rank=args.rank
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
lora_attn_procs[name] = attn_processor
|
170 |
+
|
171 |
+
dit.set_attn_processor(lora_attn_procs)
|
172 |
+
|
173 |
+
# if args.use_ip:
|
174 |
+
# # unpack checkpoint
|
175 |
+
# checkpoint = load_checkpoint(args.ip_local_path, args.ip_repo_id, args.ip_name)
|
176 |
+
# prefix = "double_blocks."
|
177 |
+
# # blocks = {}
|
178 |
+
# proj = {}
|
179 |
+
|
180 |
+
# for key, value in checkpoint.items():
|
181 |
+
# # if key.startswith(prefix):
|
182 |
+
# # blocks[key[len(prefix):].replace('.processor.', '.')] = value
|
183 |
+
# if key.startswith("ip_adapter_proj_model"):
|
184 |
+
# proj[key[len("ip_adapter_proj_model."):]] = value
|
185 |
+
|
186 |
+
# # # load image encoder
|
187 |
+
# # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
|
188 |
+
# # # accelerator.device, dtype=torch.bfloat16
|
189 |
+
# # device, dtype=torch.bfloat16
|
190 |
+
# # )
|
191 |
+
# # ip_clip_image_processor = CLIPImageProcessor()
|
192 |
+
|
193 |
+
# # setup image embedding projection model
|
194 |
+
# ip_improj = ImageProjModel(4096, 768, 4)
|
195 |
+
# ip_improj.load_state_dict(proj)
|
196 |
+
# # ip_improj = ip_improj.to(accelerator.device, dtype=torch.bfloat16)
|
197 |
+
# ip_improj = ip_improj.to(device, dtype=torch.bfloat16)
|
198 |
+
|
199 |
+
# ip_attn_procs = {}
|
200 |
+
|
201 |
+
# for name, _ in dit.attn_processors.items():
|
202 |
+
# ip_state_dict = {}
|
203 |
+
# for k in checkpoint.keys():
|
204 |
+
# if name in k:
|
205 |
+
# ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
|
206 |
+
# if ip_state_dict:
|
207 |
+
# ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
|
208 |
+
# ip_attn_procs[name].load_state_dict(ip_state_dict)
|
209 |
+
# ip_attn_procs[name].to(accelerator.device, dtype=torch.bfloat16)
|
210 |
+
# else:
|
211 |
+
# ip_attn_procs[name] = dit.attn_processors[name]
|
212 |
+
# dit.set_attn_processor(ip_attn_procs)
|
213 |
+
|
214 |
+
|
215 |
+
vae.requires_grad_(False)
|
216 |
+
t5.requires_grad_(False)
|
217 |
+
clip.requires_grad_(False)
|
218 |
+
|
219 |
+
|
220 |
+
|
221 |
+
# weight_dtype = torch.float32
|
222 |
+
# if accelerator.mixed_precision == "fp16":
|
223 |
+
# weight_dtype = torch.float16
|
224 |
+
# args.mixed_precision = accelerator.mixed_precision
|
225 |
+
# elif accelerator.mixed_precision == "bf16":
|
226 |
+
# weight_dtype = torch.bfloat16
|
227 |
+
# args.mixed_precision = accelerator.mixed_precision
|
228 |
+
|
229 |
+
|
230 |
+
# print(f"Resuming from checkpoint {args.ckpt_dir}")
|
231 |
+
# dit_stat_dict = load_file(args.ckpt_dir)
|
232 |
+
# Get path from Hub
|
233 |
+
model_path = hf_hub_download(
|
234 |
+
repo_id="Boese0601/ByteMorpher",
|
235 |
+
filename="dit.safetensors"
|
236 |
+
)
|
237 |
+
state_dict = load_file(model_path)
|
238 |
+
dit.load_state_dict(state_dict)
|
239 |
+
dit = dit.to(weight_dtype)
|
240 |
+
dit.eval()
|
241 |
+
|
242 |
+
# test_dataloader = loader(**args.data_config)
|
243 |
+
test_dataloader = eval_image_pair_loader(**args.data_config)
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
# from deepspeed import initialize
|
248 |
+
dit = accelerator.prepare(dit)
|
249 |
+
|
250 |
+
# if accelerator.is_main_process:
|
251 |
+
# accelerator.init_trackers(args.tracker_project_name, {"test": None})
|
252 |
+
|
253 |
+
# logger.info("***** Running Evaluation *****")
|
254 |
+
# logger.info(f" Instantaneous batch size = {args.eval_batch_size}")
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
# progress_bar = tqdm(
|
259 |
+
# range(0, len(test_dataloader)),
|
260 |
+
# initial=0,
|
261 |
+
# desc="Steps",
|
262 |
+
# disable=not accelerator.is_local_main_process,
|
263 |
+
# )
|
264 |
+
|
265 |
+
# for step, batch in enumerate(test_dataloader):
|
266 |
+
# with accelerator.accumulate(dit):
|
267 |
+
# img, tgt_image, prompt, edit_prompt, img_name, edit_name = batch
|
268 |
+
img = image_resize(image, 512)
|
269 |
+
w, h = img.size
|
270 |
+
new_w = (w // 32) * 32
|
271 |
+
new_h = (h // 32) * 32
|
272 |
+
img = img.resize((new_w, new_h))
|
273 |
+
img = torch.from_numpy((np.array(img) / 127.5) - 1)
|
274 |
+
img = img.permute(2, 0, 1).unsqueeze(0)
|
275 |
+
|
276 |
+
edit_prompt = edit_prompt
|
277 |
+
|
278 |
+
# if args.use_ip:
|
279 |
+
# sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
|
280 |
+
# elif args.use_spatial_condition:
|
281 |
+
# sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding)
|
282 |
+
# else:
|
283 |
+
# sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
|
284 |
+
with torch.no_grad():
|
285 |
+
result = sampler(prompt=edit_prompt,
|
286 |
+
width=args.sample_width,
|
287 |
+
height=args.sample_height,
|
288 |
+
num_steps=args.sample_steps,
|
289 |
+
image_prompt=None, # ip_adapter
|
290 |
+
true_gs=args.cfg_scale,
|
291 |
+
seed=args.seed,
|
292 |
+
ip_scale=args.ip_scale if args.use_ip else 1.0,
|
293 |
+
source_image=img if args.use_spatial_condition else None,
|
294 |
+
)
|
295 |
+
gen_img = result
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
# progress_bar.update(1)
|
300 |
+
|
301 |
+
# accelerator.wait_for_everyone()
|
302 |
+
# accelerator.end_training()
|
303 |
+
return gen_img
|
304 |
+
|
305 |
+
|
306 |
+
def get_samples():
|
307 |
+
sample_list = [
|
308 |
+
{
|
309 |
+
"image": "assets/0_camera_zoom/20486354.png",
|
310 |
+
"edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
|
311 |
+
},
|
312 |
+
]
|
313 |
+
return [
|
314 |
+
[
|
315 |
+
Image.open(sample["image"]).resize((512, 512)),
|
316 |
+
sample["edit_prompt"],
|
317 |
+
]
|
318 |
+
for sample in sample_list
|
319 |
+
]
|
320 |
+
|
321 |
+
|
322 |
+
header = """
|
323 |
+
# ByteMoprh
|
324 |
+
|
325 |
+
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
|
326 |
+
<a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
|
327 |
+
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
|
328 |
+
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
|
329 |
+
</div>
|
330 |
+
"""
|
331 |
+
|
332 |
+
|
333 |
+
def create_app():
|
334 |
+
with gr.Blocks() as app:
|
335 |
+
gr.Markdown(header, elem_id="header")
|
336 |
+
with gr.Row(equal_height=False):
|
337 |
+
with gr.Column(variant="panel", elem_classes="inputPanel"):
|
338 |
+
original_image = gr.Image(
|
339 |
+
type="pil", label="Condition Image", width=300, elem_id="input"
|
340 |
+
)
|
341 |
+
edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
|
342 |
+
submit_btn = gr.Button("Run", elem_id="submit_btn")
|
343 |
+
|
344 |
+
with gr.Column(variant="panel", elem_classes="outputPanel"):
|
345 |
+
output_image = gr.Image(type="pil", elem_id="output")
|
346 |
+
|
347 |
+
with gr.Row():
|
348 |
+
examples = gr.Examples(
|
349 |
+
examples=get_samples(),
|
350 |
+
inputs=[original_image, edit_prompt],
|
351 |
+
label="Examples",
|
352 |
+
)
|
353 |
+
|
354 |
+
submit_btn.click(
|
355 |
+
fn=generate,
|
356 |
+
inputs=[original_image, edit_prompt],
|
357 |
+
outputs=output_image,
|
358 |
+
)
|
359 |
+
gr.HTML(
|
360 |
+
"""
|
361 |
+
<div style="text-align: center;">
|
362 |
+
* This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
|
363 |
+
</div>
|
364 |
+
"""
|
365 |
+
)
|
366 |
+
return app
|
367 |
+
|
368 |
+
|
369 |
+
if __name__ == "__main__":
|
370 |
+
print("CUDA available:", torch.cuda.is_available())
|
371 |
+
print("CUDA version:", torch.version.cuda)
|
372 |
+
print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
|
373 |
+
# mp.set_start_method("spawn", force=True)
|
374 |
+
create_app().launch(debug=False, share=True, ssr_mode=False)
|
image_datasets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
image_datasets/dataset.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
import json
|
8 |
+
import random
|
9 |
+
import glob
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms.functional as TF
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def image_resize(img, max_size=512):
|
16 |
+
w, h = img.size
|
17 |
+
if w >= h:
|
18 |
+
new_w = max_size
|
19 |
+
new_h = int((max_size / w) * h)
|
20 |
+
else:
|
21 |
+
new_h = max_size
|
22 |
+
new_w = int((max_size / h) * w)
|
23 |
+
return img.resize((new_w, new_h))
|
24 |
+
|
25 |
+
def c_crop(image):
|
26 |
+
width, height = image.size
|
27 |
+
new_size = min(width, height)
|
28 |
+
left = (width - new_size) / 2
|
29 |
+
top = (height - new_size) / 2
|
30 |
+
right = (width + new_size) / 2
|
31 |
+
bottom = (height + new_size) / 2
|
32 |
+
return image.crop((left, top, right, bottom))
|
33 |
+
|
34 |
+
def crop_to_aspect_ratio(image, ratio="16:9"):
|
35 |
+
width, height = image.size
|
36 |
+
ratio_map = {
|
37 |
+
"16:9": (16, 9),
|
38 |
+
"4:3": (4, 3),
|
39 |
+
"1:1": (1, 1)
|
40 |
+
}
|
41 |
+
target_w, target_h = ratio_map[ratio]
|
42 |
+
target_ratio_value = target_w / target_h
|
43 |
+
|
44 |
+
current_ratio = width / height
|
45 |
+
|
46 |
+
if current_ratio > target_ratio_value:
|
47 |
+
new_width = int(height * target_ratio_value)
|
48 |
+
offset = (width - new_width) // 2
|
49 |
+
crop_box = (offset, 0, offset + new_width, height)
|
50 |
+
else:
|
51 |
+
new_height = int(width / target_ratio_value)
|
52 |
+
offset = (height - new_height) // 2
|
53 |
+
crop_box = (0, offset, width, offset + new_height)
|
54 |
+
|
55 |
+
cropped_img = image.crop(crop_box)
|
56 |
+
return cropped_img
|
57 |
+
|
58 |
+
|
59 |
+
class CustomImageDataset(Dataset):
|
60 |
+
def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False):
|
61 |
+
self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
|
62 |
+
# self.images = glob.glob(img_dir +'**/*.jpg', recursive=True) + glob.glob(img_dir +'**/*.png', recursive=True) + glob.glob(img_dir +'**/*.jpeg', recursive=True)
|
63 |
+
self.images.sort()
|
64 |
+
self.img_size = img_size
|
65 |
+
self.caption_type = caption_type
|
66 |
+
self.random_ratio = random_ratio
|
67 |
+
|
68 |
+
def __len__(self):
|
69 |
+
return len(self.images)
|
70 |
+
|
71 |
+
def __getitem__(self, idx):
|
72 |
+
try:
|
73 |
+
img = Image.open(self.images[idx]).convert('RGB')
|
74 |
+
|
75 |
+
if self.random_ratio:
|
76 |
+
ratio = random.choice(["16:9", "default", "1:1", "4:3"])
|
77 |
+
if ratio != "default":
|
78 |
+
img = crop_to_aspect_ratio(img, ratio)
|
79 |
+
img = image_resize(img, self.img_size)
|
80 |
+
w, h = img.size
|
81 |
+
new_w = (w // 32) * 32
|
82 |
+
new_h = (h // 32) * 32
|
83 |
+
img = img.resize((new_w, new_h))
|
84 |
+
img = torch.from_numpy((np.array(img) / 127.5) - 1)
|
85 |
+
img = img.permute(2, 0, 1)
|
86 |
+
json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
|
87 |
+
if self.caption_type == "json":
|
88 |
+
prompt = json.load(open(json_path))['caption']
|
89 |
+
else:
|
90 |
+
prompt = open(json_path).read()
|
91 |
+
return img, prompt
|
92 |
+
except Exception as e:
|
93 |
+
print(e)
|
94 |
+
return self.__getitem__(random.randint(0, len(self.images) - 1))
|
95 |
+
|
96 |
+
|
97 |
+
def loader(train_batch_size, num_workers, **args):
|
98 |
+
dataset = CustomImageDataset(**args)
|
99 |
+
return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
class ImageEditPairDataset(Dataset):
|
104 |
+
def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False, grayscale_editing=False, zoom_camera=False):
|
105 |
+
# self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
|
106 |
+
self.images = glob.glob(img_dir +'**/*.jpg', recursive=True) + glob.glob(img_dir +'**/*.png', recursive=True) + glob.glob(img_dir +'**/*.jpeg', recursive=True)
|
107 |
+
self.images.sort()
|
108 |
+
self.img_size = img_size
|
109 |
+
self.caption_type = caption_type
|
110 |
+
self.random_ratio = random_ratio
|
111 |
+
self.grayscale_editing = grayscale_editing
|
112 |
+
self.zoom_camera = zoom_camera
|
113 |
+
if "ByteMorph-Bench" or "InstructMove" in img_dir:
|
114 |
+
self.eval = True
|
115 |
+
else:
|
116 |
+
self.eval = False
|
117 |
+
def __len__(self):
|
118 |
+
return len(self.images)
|
119 |
+
|
120 |
+
def __getitem__(self, idx):
|
121 |
+
try:
|
122 |
+
img = Image.open(self.images[idx]).convert('RGB')
|
123 |
+
ori_width, ori_height = img.size
|
124 |
+
left_half = (0, 0, ori_width // 2, ori_height)
|
125 |
+
right_half = (ori_width // 2, 0, ori_width, ori_height)
|
126 |
+
src_image = img.crop(left_half) # Left half
|
127 |
+
tgt_image = img.crop(right_half) # Right half
|
128 |
+
# print("ori_width, ori_height: ",ori_width, ori_height)
|
129 |
+
if self.random_ratio:
|
130 |
+
ratio = random.choice(["16:9", "default", "1:1", "4:3"])
|
131 |
+
if ratio != "default":
|
132 |
+
src_image = crop_to_aspect_ratio(src_image, ratio)
|
133 |
+
tgt_image = crop_to_aspect_ratio(tgt_image, ratio)
|
134 |
+
src_image = image_resize(src_image, self.img_size)
|
135 |
+
tgt_image = image_resize(tgt_image, self.img_size)
|
136 |
+
w, h = src_image.size
|
137 |
+
new_w = (w // 32) * 32
|
138 |
+
new_h = (h // 32) * 32
|
139 |
+
# print("new_w, new_h: ",new_w, new_h)
|
140 |
+
src_image = src_image.resize((new_w, new_h))
|
141 |
+
src_image = torch.from_numpy((np.array(src_image) / 127.5) - 1)
|
142 |
+
src_image = src_image.permute(2, 0, 1)
|
143 |
+
tgt_image = tgt_image.resize((new_w, new_h))
|
144 |
+
tgt_image = torch.from_numpy((np.array(tgt_image) / 127.5) - 1)
|
145 |
+
tgt_image = tgt_image.permute(2, 0, 1)
|
146 |
+
json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
|
147 |
+
if self.eval:
|
148 |
+
image_name = self.images[idx].split('.')[0].split("/")[-1]
|
149 |
+
edit_type = self.images[idx].split('.')[0].split("/")[-2]
|
150 |
+
if self.caption_type == "json":
|
151 |
+
if not self.eval:
|
152 |
+
prompt = json.load(open(json_path))['caption']
|
153 |
+
edit_prompt = json.load(open(json_path))['edit']
|
154 |
+
else:
|
155 |
+
prompt = [] #json.load(open(json_path))['caption']
|
156 |
+
edit_prompt = json.load(open(json_path))['edit']
|
157 |
+
else:
|
158 |
+
raise NotImplementedError
|
159 |
+
# prompt = open(json_path).read()
|
160 |
+
if (not self.grayscale_editing) and (not self.zoom_camera):
|
161 |
+
if not self.eval:
|
162 |
+
return src_image, tgt_image, prompt, edit_prompt
|
163 |
+
else:
|
164 |
+
return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
|
165 |
+
if self.grayscale_editing and (not self.zoom_camera):
|
166 |
+
# Grayscale = 0.2989 * R + 0.5870 * G + 0.1140 * B
|
167 |
+
grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :]
|
168 |
+
tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1)
|
169 |
+
edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details."
|
170 |
+
if not self.eval:
|
171 |
+
return src_image, tgt_image, prompt, edit_prompt
|
172 |
+
else:
|
173 |
+
return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
|
174 |
+
if (not self.grayscale_editing) and self.zoom_camera:
|
175 |
+
cropped = TF.center_crop(src_image, (256, 256))
|
176 |
+
tgt_image = TF.resize(cropped, (512, 512))
|
177 |
+
edit_prompt = "The central area of the input image is zoomed. The camera transitions from a wide shot to a closer position, narrowing its view."
|
178 |
+
if not self.eval:
|
179 |
+
return src_image, tgt_image, prompt, edit_prompt
|
180 |
+
else:
|
181 |
+
return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
|
182 |
+
if self.grayscale_editing and self.zoom_camera:
|
183 |
+
grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :]
|
184 |
+
tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1)
|
185 |
+
tgt_image = TF.center_crop(tgt_image, (256, 256))
|
186 |
+
tgt_image = TF.resize(tgt_image, (512, 512))
|
187 |
+
edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details. And the central area of the input image is zoomed, the camera transitions from a wide shot to a closer position, narrowing its view."
|
188 |
+
if not self.eval:
|
189 |
+
return src_image, tgt_image, prompt, edit_prompt
|
190 |
+
else:
|
191 |
+
return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
|
192 |
+
except Exception as e:
|
193 |
+
print(e)
|
194 |
+
return self.__getitem__(random.randint(0, len(self.images) - 1))
|
195 |
+
|
196 |
+
|
197 |
+
def image_pair_loader(train_batch_size, num_workers, **args):
|
198 |
+
dataset = ImageEditPairDataset(**args)
|
199 |
+
return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
|
200 |
+
|
201 |
+
def eval_image_pair_loader(eval_batch_size, num_workers, **args):
|
202 |
+
dataset = ImageEditPairDataset(**args)
|
203 |
+
return DataLoader(dataset, batch_size=eval_batch_size, num_workers=num_workers, shuffle=False)
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
from src.flux.util import save_image
|
209 |
+
example_dataset = ImageEditPairDataset(
|
210 |
+
img_dir="",
|
211 |
+
img_size=512,
|
212 |
+
caption_type='json',
|
213 |
+
random_ratio=False,
|
214 |
+
grayscale_editing=False,
|
215 |
+
zoom_camera=False,
|
216 |
+
)
|
217 |
+
|
218 |
+
train_dataloader = DataLoader(
|
219 |
+
example_dataset,
|
220 |
+
batch_size=1,
|
221 |
+
num_workers=4,
|
222 |
+
shuffle=False,
|
223 |
+
)
|
224 |
+
|
225 |
+
for step, batch in enumerate(train_dataloader):
|
226 |
+
src_image, tgt_image, prompt, edit_prompt = batch
|
227 |
+
os.makedirs("./debug", exist_ok=True)
|
228 |
+
save_image(src_image, f"./debug/{step}-src_img.jpg")
|
229 |
+
save_image(tgt_image, f"./debug/{step}-tgt_img.jpg")
|
230 |
+
if step == 3:
|
231 |
+
breakpoint()
|
inference_configs/inference.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "flux-dev"
|
2 |
+
use_spatial_condition: true
|
3 |
+
share_position_embedding: true
|
4 |
+
use_share_weight_referencenet: false
|
5 |
+
use_ip: false
|
6 |
+
ip_local_path: null
|
7 |
+
ip_repo_id: null
|
8 |
+
ip_name: null
|
9 |
+
ip_scale: 1.0
|
10 |
+
use_lora: false
|
11 |
+
data_config:
|
12 |
+
eval_batch_size: 1
|
13 |
+
num_workers: 0
|
14 |
+
img_size: 512
|
15 |
+
img_dir: output_bench/ #./ByteMorph-Bench/
|
16 |
+
grayscale_editing: false
|
17 |
+
zoom_camera: false
|
18 |
+
random_ratio: false
|
19 |
+
report_to: wandb
|
20 |
+
eval_batch_size: 1
|
21 |
+
ckpt_dir: ./pretrained_weights/ByteMorpher/dit.safetensors
|
22 |
+
output_dir: ./test_log/seedmorpher/
|
23 |
+
logging_dir: logs
|
24 |
+
mixed_precision: "bf16"
|
25 |
+
rank: 16
|
26 |
+
single_blocks: null
|
27 |
+
double_blocks: null
|
28 |
+
disable_sampling: false
|
29 |
+
sample_width: 512
|
30 |
+
sample_height: 512
|
31 |
+
sample_steps: 25
|
32 |
+
seed: 42
|
33 |
+
cfg_scale: 3.5
|
requirements.txt
CHANGED
@@ -1,6 +1,26 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
diffusers
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --extra-index-url https://download.pytorch.org/whl/cu124
|
2 |
+
# torch==2.6.0
|
3 |
+
# torchvision==0.21.0
|
4 |
+
# torchaudio==2.6.0
|
5 |
+
|
6 |
+
gradio>=4.0
|
7 |
+
accelerate==0.30.1
|
8 |
+
deepspeed==0.14.4
|
9 |
+
einops==0.8.0
|
10 |
+
transformers==4.43.3
|
11 |
+
huggingface-hub==0.24.5
|
12 |
+
optimum-quanto
|
13 |
+
datasets
|
14 |
+
omegaconf
|
15 |
diffusers
|
16 |
+
sentencepiece
|
17 |
+
opencv-python
|
18 |
+
matplotlib
|
19 |
+
onnxruntime
|
20 |
+
timm
|
21 |
+
wandb
|
22 |
+
|
23 |
+
setuptools
|
24 |
+
wheel
|
25 |
+
|
26 |
+
|
src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/flux/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/flux/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from ._version import version as __version__ # type: ignore
|
3 |
+
from ._version import version_tuple
|
4 |
+
except ImportError:
|
5 |
+
__version__ = "unknown (no version information available)"
|
6 |
+
version_tuple = (0, 0, "unknown", "noinfo")
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
PACKAGE = __package__.replace("_", "-")
|
11 |
+
PACKAGE_ROOT = Path(__file__).parent
|
src/flux/__main__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cli import app
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
app()
|
src/flux/annotator/canny/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
class CannyDetector:
|
5 |
+
def __call__(self, img, low_threshold, high_threshold):
|
6 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
src/flux/annotator/ckpts/ckpts.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Weights here.
|
src/flux/annotator/dwpose/__init__.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Openpose
|
2 |
+
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
|
3 |
+
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
|
4 |
+
# 3rd Edited by ControlNet
|
5 |
+
# 4th Edited by ControlNet (added face and correct hands)
|
6 |
+
|
7 |
+
import os
|
8 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
from . import util
|
13 |
+
from .wholebody import Wholebody
|
14 |
+
|
15 |
+
def draw_pose(pose, H, W):
|
16 |
+
bodies = pose['bodies']
|
17 |
+
faces = pose['faces']
|
18 |
+
hands = pose['hands']
|
19 |
+
candidate = bodies['candidate']
|
20 |
+
subset = bodies['subset']
|
21 |
+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
22 |
+
|
23 |
+
canvas = util.draw_bodypose(canvas, candidate, subset)
|
24 |
+
|
25 |
+
canvas = util.draw_handpose(canvas, hands)
|
26 |
+
|
27 |
+
canvas = util.draw_facepose(canvas, faces)
|
28 |
+
|
29 |
+
return canvas
|
30 |
+
|
31 |
+
|
32 |
+
class DWposeDetector:
|
33 |
+
def __init__(self, device):
|
34 |
+
|
35 |
+
self.pose_estimation = Wholebody(device)
|
36 |
+
|
37 |
+
def __call__(self, oriImg):
|
38 |
+
oriImg = oriImg.copy()
|
39 |
+
H, W, C = oriImg.shape
|
40 |
+
with torch.no_grad():
|
41 |
+
candidate, subset = self.pose_estimation(oriImg)
|
42 |
+
nums, keys, locs = candidate.shape
|
43 |
+
candidate[..., 0] /= float(W)
|
44 |
+
candidate[..., 1] /= float(H)
|
45 |
+
body = candidate[:,:18].copy()
|
46 |
+
body = body.reshape(nums*18, locs)
|
47 |
+
score = subset[:,:18]
|
48 |
+
for i in range(len(score)):
|
49 |
+
for j in range(len(score[i])):
|
50 |
+
if score[i][j] > 0.3:
|
51 |
+
score[i][j] = int(18*i+j)
|
52 |
+
else:
|
53 |
+
score[i][j] = -1
|
54 |
+
|
55 |
+
un_visible = subset<0.3
|
56 |
+
candidate[un_visible] = -1
|
57 |
+
|
58 |
+
foot = candidate[:,18:24]
|
59 |
+
|
60 |
+
faces = candidate[:,24:92]
|
61 |
+
|
62 |
+
hands = candidate[:,92:113]
|
63 |
+
hands = np.vstack([hands, candidate[:,113:]])
|
64 |
+
|
65 |
+
bodies = dict(candidate=body, subset=score)
|
66 |
+
pose = dict(bodies=bodies, hands=hands, faces=faces)
|
67 |
+
|
68 |
+
return draw_pose(pose, H, W)
|
src/flux/annotator/dwpose/onnxdet.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import onnxruntime
|
5 |
+
|
6 |
+
def nms(boxes, scores, nms_thr):
|
7 |
+
"""Single class NMS implemented in Numpy."""
|
8 |
+
x1 = boxes[:, 0]
|
9 |
+
y1 = boxes[:, 1]
|
10 |
+
x2 = boxes[:, 2]
|
11 |
+
y2 = boxes[:, 3]
|
12 |
+
|
13 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
14 |
+
order = scores.argsort()[::-1]
|
15 |
+
|
16 |
+
keep = []
|
17 |
+
while order.size > 0:
|
18 |
+
i = order[0]
|
19 |
+
keep.append(i)
|
20 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
21 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
22 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
23 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
24 |
+
|
25 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
26 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
27 |
+
inter = w * h
|
28 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
29 |
+
|
30 |
+
inds = np.where(ovr <= nms_thr)[0]
|
31 |
+
order = order[inds + 1]
|
32 |
+
|
33 |
+
return keep
|
34 |
+
|
35 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
36 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
37 |
+
final_dets = []
|
38 |
+
num_classes = scores.shape[1]
|
39 |
+
for cls_ind in range(num_classes):
|
40 |
+
cls_scores = scores[:, cls_ind]
|
41 |
+
valid_score_mask = cls_scores > score_thr
|
42 |
+
if valid_score_mask.sum() == 0:
|
43 |
+
continue
|
44 |
+
else:
|
45 |
+
valid_scores = cls_scores[valid_score_mask]
|
46 |
+
valid_boxes = boxes[valid_score_mask]
|
47 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
48 |
+
if len(keep) > 0:
|
49 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
50 |
+
dets = np.concatenate(
|
51 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
52 |
+
)
|
53 |
+
final_dets.append(dets)
|
54 |
+
if len(final_dets) == 0:
|
55 |
+
return None
|
56 |
+
return np.concatenate(final_dets, 0)
|
57 |
+
|
58 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
59 |
+
grids = []
|
60 |
+
expanded_strides = []
|
61 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
62 |
+
|
63 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
64 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
65 |
+
|
66 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
67 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
68 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
69 |
+
grids.append(grid)
|
70 |
+
shape = grid.shape[:2]
|
71 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
72 |
+
|
73 |
+
grids = np.concatenate(grids, 1)
|
74 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
75 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
76 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
77 |
+
|
78 |
+
return outputs
|
79 |
+
|
80 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
81 |
+
if len(img.shape) == 3:
|
82 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
83 |
+
else:
|
84 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
85 |
+
|
86 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
87 |
+
resized_img = cv2.resize(
|
88 |
+
img,
|
89 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
90 |
+
interpolation=cv2.INTER_LINEAR,
|
91 |
+
).astype(np.uint8)
|
92 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
93 |
+
|
94 |
+
padded_img = padded_img.transpose(swap)
|
95 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
96 |
+
return padded_img, r
|
97 |
+
|
98 |
+
def inference_detector(session, oriImg):
|
99 |
+
input_shape = (640,640)
|
100 |
+
img, ratio = preprocess(oriImg, input_shape)
|
101 |
+
|
102 |
+
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
103 |
+
output = session.run(None, ort_inputs)
|
104 |
+
predictions = demo_postprocess(output[0], input_shape)[0]
|
105 |
+
|
106 |
+
boxes = predictions[:, :4]
|
107 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
108 |
+
|
109 |
+
boxes_xyxy = np.ones_like(boxes)
|
110 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
111 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
112 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
113 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
114 |
+
boxes_xyxy /= ratio
|
115 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
116 |
+
if dets is not None:
|
117 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
118 |
+
isscore = final_scores>0.3
|
119 |
+
iscat = final_cls_inds == 0
|
120 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
121 |
+
final_boxes = final_boxes[isbbox]
|
122 |
+
else:
|
123 |
+
final_boxes = np.array([])
|
124 |
+
|
125 |
+
return final_boxes
|
src/flux/annotator/dwpose/onnxpose.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
|
7 |
+
def preprocess(
|
8 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
9 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
10 |
+
"""Do preprocessing for RTMPose model inference.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (np.ndarray): Input image in shape.
|
14 |
+
input_size (tuple): Input image size in shape (w, h).
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple:
|
18 |
+
- resized_img (np.ndarray): Preprocessed image.
|
19 |
+
- center (np.ndarray): Center of image.
|
20 |
+
- scale (np.ndarray): Scale of image.
|
21 |
+
"""
|
22 |
+
# get shape of image
|
23 |
+
img_shape = img.shape[:2]
|
24 |
+
out_img, out_center, out_scale = [], [], []
|
25 |
+
if len(out_bbox) == 0:
|
26 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
27 |
+
for i in range(len(out_bbox)):
|
28 |
+
x0 = out_bbox[i][0]
|
29 |
+
y0 = out_bbox[i][1]
|
30 |
+
x1 = out_bbox[i][2]
|
31 |
+
y1 = out_bbox[i][3]
|
32 |
+
bbox = np.array([x0, y0, x1, y1])
|
33 |
+
|
34 |
+
# get center and scale
|
35 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
36 |
+
|
37 |
+
# do affine transformation
|
38 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
39 |
+
|
40 |
+
# normalize image
|
41 |
+
mean = np.array([123.675, 116.28, 103.53])
|
42 |
+
std = np.array([58.395, 57.12, 57.375])
|
43 |
+
resized_img = (resized_img - mean) / std
|
44 |
+
|
45 |
+
out_img.append(resized_img)
|
46 |
+
out_center.append(center)
|
47 |
+
out_scale.append(scale)
|
48 |
+
|
49 |
+
return out_img, out_center, out_scale
|
50 |
+
|
51 |
+
|
52 |
+
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
53 |
+
"""Inference RTMPose model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sess (ort.InferenceSession): ONNXRuntime session.
|
57 |
+
img (np.ndarray): Input image in shape.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
outputs (np.ndarray): Output of RTMPose model.
|
61 |
+
"""
|
62 |
+
all_out = []
|
63 |
+
# build input
|
64 |
+
for i in range(len(img)):
|
65 |
+
input = [img[i].transpose(2, 0, 1)]
|
66 |
+
|
67 |
+
# build output
|
68 |
+
sess_input = {sess.get_inputs()[0].name: input}
|
69 |
+
sess_output = []
|
70 |
+
for out in sess.get_outputs():
|
71 |
+
sess_output.append(out.name)
|
72 |
+
|
73 |
+
# run model
|
74 |
+
outputs = sess.run(sess_output, sess_input)
|
75 |
+
all_out.append(outputs)
|
76 |
+
|
77 |
+
return all_out
|
78 |
+
|
79 |
+
|
80 |
+
def postprocess(outputs: List[np.ndarray],
|
81 |
+
model_input_size: Tuple[int, int],
|
82 |
+
center: Tuple[int, int],
|
83 |
+
scale: Tuple[int, int],
|
84 |
+
simcc_split_ratio: float = 2.0
|
85 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
86 |
+
"""Postprocess for RTMPose model output.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
outputs (np.ndarray): Output of RTMPose model.
|
90 |
+
model_input_size (tuple): RTMPose model Input image size.
|
91 |
+
center (tuple): Center of bbox in shape (x, y).
|
92 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
93 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
tuple:
|
97 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
98 |
+
- scores (np.ndarray): Model predict scores.
|
99 |
+
"""
|
100 |
+
all_key = []
|
101 |
+
all_score = []
|
102 |
+
for i in range(len(outputs)):
|
103 |
+
# use simcc to decode
|
104 |
+
simcc_x, simcc_y = outputs[i]
|
105 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
106 |
+
|
107 |
+
# rescale keypoints
|
108 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
109 |
+
all_key.append(keypoints[0])
|
110 |
+
all_score.append(scores[0])
|
111 |
+
|
112 |
+
return np.array(all_key), np.array(all_score)
|
113 |
+
|
114 |
+
|
115 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
116 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
117 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
118 |
+
|
119 |
+
Args:
|
120 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
121 |
+
as (left, top, right, bottom)
|
122 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
123 |
+
Default: 1.0
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
tuple: A tuple containing center and scale.
|
127 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
128 |
+
(n, 2)
|
129 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
130 |
+
(n, 2)
|
131 |
+
"""
|
132 |
+
# convert single bbox from (4, ) to (1, 4)
|
133 |
+
dim = bbox.ndim
|
134 |
+
if dim == 1:
|
135 |
+
bbox = bbox[None, :]
|
136 |
+
|
137 |
+
# get bbox center and scale
|
138 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
139 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
140 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
141 |
+
|
142 |
+
if dim == 1:
|
143 |
+
center = center[0]
|
144 |
+
scale = scale[0]
|
145 |
+
|
146 |
+
return center, scale
|
147 |
+
|
148 |
+
|
149 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
150 |
+
aspect_ratio: float) -> np.ndarray:
|
151 |
+
"""Extend the scale to match the given aspect ratio.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
155 |
+
aspect_ratio (float): The ratio of ``w/h``
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
np.ndarray: The reshaped image scale in (2, )
|
159 |
+
"""
|
160 |
+
w, h = np.hsplit(bbox_scale, [1])
|
161 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
162 |
+
np.hstack([w, w / aspect_ratio]),
|
163 |
+
np.hstack([h * aspect_ratio, h]))
|
164 |
+
return bbox_scale
|
165 |
+
|
166 |
+
|
167 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
168 |
+
"""Rotate a point by an angle.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
172 |
+
angle_rad (float): rotation angle in radian
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
np.ndarray: Rotated point in shape (2, )
|
176 |
+
"""
|
177 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
178 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
179 |
+
return rot_mat @ pt
|
180 |
+
|
181 |
+
|
182 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
183 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
184 |
+
function is used to get the 3rd point, given 2D points a & b.
|
185 |
+
|
186 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
187 |
+
anticlockwise, using b as the rotation center.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
191 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
np.ndarray: The 3rd point.
|
195 |
+
"""
|
196 |
+
direction = a - b
|
197 |
+
c = b + np.r_[-direction[1], direction[0]]
|
198 |
+
return c
|
199 |
+
|
200 |
+
|
201 |
+
def get_warp_matrix(center: np.ndarray,
|
202 |
+
scale: np.ndarray,
|
203 |
+
rot: float,
|
204 |
+
output_size: Tuple[int, int],
|
205 |
+
shift: Tuple[float, float] = (0., 0.),
|
206 |
+
inv: bool = False) -> np.ndarray:
|
207 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
208 |
+
in the input image to the output size.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
212 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
213 |
+
wrt [width, height].
|
214 |
+
rot (float): Rotation angle (degree).
|
215 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
216 |
+
destination heatmaps.
|
217 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
218 |
+
Default (0., 0.).
|
219 |
+
inv (bool): Option to inverse the affine transform direction.
|
220 |
+
(inv=False: src->dst or inv=True: dst->src)
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
np.ndarray: A 2x3 transformation matrix
|
224 |
+
"""
|
225 |
+
shift = np.array(shift)
|
226 |
+
src_w = scale[0]
|
227 |
+
dst_w = output_size[0]
|
228 |
+
dst_h = output_size[1]
|
229 |
+
|
230 |
+
# compute transformation matrix
|
231 |
+
rot_rad = np.deg2rad(rot)
|
232 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
233 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
234 |
+
|
235 |
+
# get four corners of the src rectangle in the original image
|
236 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
237 |
+
src[0, :] = center + scale * shift
|
238 |
+
src[1, :] = center + src_dir + scale * shift
|
239 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
240 |
+
|
241 |
+
# get four corners of the dst rectangle in the input image
|
242 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
243 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
244 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
245 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
246 |
+
|
247 |
+
if inv:
|
248 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
249 |
+
else:
|
250 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
251 |
+
|
252 |
+
return warp_mat
|
253 |
+
|
254 |
+
|
255 |
+
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
256 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
257 |
+
"""Get the bbox image as the model input by affine transform.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
input_size (dict): The input size of the model.
|
261 |
+
bbox_scale (dict): The bbox scale of the img.
|
262 |
+
bbox_center (dict): The bbox center of the img.
|
263 |
+
img (np.ndarray): The original image.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
tuple: A tuple containing center and scale.
|
267 |
+
- np.ndarray[float32]: img after affine transform.
|
268 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
269 |
+
"""
|
270 |
+
w, h = input_size
|
271 |
+
warp_size = (int(w), int(h))
|
272 |
+
|
273 |
+
# reshape bbox to fixed aspect ratio
|
274 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
275 |
+
|
276 |
+
# get the affine matrix
|
277 |
+
center = bbox_center
|
278 |
+
scale = bbox_scale
|
279 |
+
rot = 0
|
280 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
281 |
+
|
282 |
+
# do affine transform
|
283 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
284 |
+
|
285 |
+
return img, bbox_scale
|
286 |
+
|
287 |
+
|
288 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
289 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
290 |
+
"""Get maximum response location and value from simcc representations.
|
291 |
+
|
292 |
+
Note:
|
293 |
+
instance number: N
|
294 |
+
num_keypoints: K
|
295 |
+
heatmap height: H
|
296 |
+
heatmap width: W
|
297 |
+
|
298 |
+
Args:
|
299 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
300 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
tuple:
|
304 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
305 |
+
(K, 2) or (N, K, 2)
|
306 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
307 |
+
(K,) or (N, K)
|
308 |
+
"""
|
309 |
+
N, K, Wx = simcc_x.shape
|
310 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
311 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
312 |
+
|
313 |
+
# get maximum value locations
|
314 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
315 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
316 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
317 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
318 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
319 |
+
|
320 |
+
# get maximum value across x and y axis
|
321 |
+
mask = max_val_x > max_val_y
|
322 |
+
max_val_x[mask] = max_val_y[mask]
|
323 |
+
vals = max_val_x
|
324 |
+
locs[vals <= 0.] = -1
|
325 |
+
|
326 |
+
# reshape
|
327 |
+
locs = locs.reshape(N, K, 2)
|
328 |
+
vals = vals.reshape(N, K)
|
329 |
+
|
330 |
+
return locs, vals
|
331 |
+
|
332 |
+
|
333 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
334 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
335 |
+
"""Modulate simcc distribution with Gaussian.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
339 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
340 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
tuple: A tuple containing center and scale.
|
344 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
345 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
346 |
+
"""
|
347 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
348 |
+
keypoints /= simcc_split_ratio
|
349 |
+
|
350 |
+
return keypoints, scores
|
351 |
+
|
352 |
+
|
353 |
+
def inference_pose(session, out_bbox, oriImg):
|
354 |
+
h, w = session.get_inputs()[0].shape[2:]
|
355 |
+
model_input_size = (w, h)
|
356 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
357 |
+
outputs = inference(session, resized_img)
|
358 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
359 |
+
|
360 |
+
return keypoints, scores
|
src/flux/annotator/dwpose/util.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
eps = 0.01
|
8 |
+
|
9 |
+
|
10 |
+
def smart_resize(x, s):
|
11 |
+
Ht, Wt = s
|
12 |
+
if x.ndim == 2:
|
13 |
+
Ho, Wo = x.shape
|
14 |
+
Co = 1
|
15 |
+
else:
|
16 |
+
Ho, Wo, Co = x.shape
|
17 |
+
if Co == 3 or Co == 1:
|
18 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
19 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
20 |
+
else:
|
21 |
+
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
22 |
+
|
23 |
+
|
24 |
+
def smart_resize_k(x, fx, fy):
|
25 |
+
if x.ndim == 2:
|
26 |
+
Ho, Wo = x.shape
|
27 |
+
Co = 1
|
28 |
+
else:
|
29 |
+
Ho, Wo, Co = x.shape
|
30 |
+
Ht, Wt = Ho * fy, Wo * fx
|
31 |
+
if Co == 3 or Co == 1:
|
32 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
33 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
34 |
+
else:
|
35 |
+
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
36 |
+
|
37 |
+
|
38 |
+
def padRightDownCorner(img, stride, padValue):
|
39 |
+
h = img.shape[0]
|
40 |
+
w = img.shape[1]
|
41 |
+
|
42 |
+
pad = 4 * [None]
|
43 |
+
pad[0] = 0 # up
|
44 |
+
pad[1] = 0 # left
|
45 |
+
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
46 |
+
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
47 |
+
|
48 |
+
img_padded = img
|
49 |
+
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
|
50 |
+
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
51 |
+
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
|
52 |
+
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
53 |
+
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
|
54 |
+
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
55 |
+
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
|
56 |
+
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
57 |
+
|
58 |
+
return img_padded, pad
|
59 |
+
|
60 |
+
|
61 |
+
def transfer(model, model_weights):
|
62 |
+
transfered_model_weights = {}
|
63 |
+
for weights_name in model.state_dict().keys():
|
64 |
+
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
|
65 |
+
return transfered_model_weights
|
66 |
+
|
67 |
+
|
68 |
+
def draw_bodypose(canvas, candidate, subset):
|
69 |
+
H, W, C = canvas.shape
|
70 |
+
candidate = np.array(candidate)
|
71 |
+
subset = np.array(subset)
|
72 |
+
|
73 |
+
stickwidth = 4
|
74 |
+
|
75 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
76 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
77 |
+
[1, 16], [16, 18], [3, 17], [6, 18]]
|
78 |
+
|
79 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
80 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
81 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
82 |
+
|
83 |
+
for i in range(17):
|
84 |
+
for n in range(len(subset)):
|
85 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
86 |
+
if -1 in index:
|
87 |
+
continue
|
88 |
+
Y = candidate[index.astype(int), 0] * float(W)
|
89 |
+
X = candidate[index.astype(int), 1] * float(H)
|
90 |
+
mX = np.mean(X)
|
91 |
+
mY = np.mean(Y)
|
92 |
+
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
93 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
94 |
+
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
95 |
+
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
96 |
+
|
97 |
+
canvas = (canvas * 0.6).astype(np.uint8)
|
98 |
+
|
99 |
+
for i in range(18):
|
100 |
+
for n in range(len(subset)):
|
101 |
+
index = int(subset[n][i])
|
102 |
+
if index == -1:
|
103 |
+
continue
|
104 |
+
x, y = candidate[index][0:2]
|
105 |
+
x = int(x * W)
|
106 |
+
y = int(y * H)
|
107 |
+
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
108 |
+
|
109 |
+
return canvas
|
110 |
+
|
111 |
+
|
112 |
+
def draw_handpose(canvas, all_hand_peaks):
|
113 |
+
H, W, C = canvas.shape
|
114 |
+
|
115 |
+
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
116 |
+
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
117 |
+
|
118 |
+
for peaks in all_hand_peaks:
|
119 |
+
peaks = np.array(peaks)
|
120 |
+
|
121 |
+
for ie, e in enumerate(edges):
|
122 |
+
x1, y1 = peaks[e[0]]
|
123 |
+
x2, y2 = peaks[e[1]]
|
124 |
+
x1 = int(x1 * W)
|
125 |
+
y1 = int(y1 * H)
|
126 |
+
x2 = int(x2 * W)
|
127 |
+
y2 = int(y2 * H)
|
128 |
+
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
129 |
+
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
130 |
+
|
131 |
+
for i, keyponit in enumerate(peaks):
|
132 |
+
x, y = keyponit
|
133 |
+
x = int(x * W)
|
134 |
+
y = int(y * H)
|
135 |
+
if x > eps and y > eps:
|
136 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
137 |
+
return canvas
|
138 |
+
|
139 |
+
|
140 |
+
def draw_facepose(canvas, all_lmks):
|
141 |
+
H, W, C = canvas.shape
|
142 |
+
for lmks in all_lmks:
|
143 |
+
lmks = np.array(lmks)
|
144 |
+
for lmk in lmks:
|
145 |
+
x, y = lmk
|
146 |
+
x = int(x * W)
|
147 |
+
y = int(y * H)
|
148 |
+
if x > eps and y > eps:
|
149 |
+
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
150 |
+
return canvas
|
151 |
+
|
152 |
+
|
153 |
+
# detect hand according to body pose keypoints
|
154 |
+
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
155 |
+
def handDetect(candidate, subset, oriImg):
|
156 |
+
# right hand: wrist 4, elbow 3, shoulder 2
|
157 |
+
# left hand: wrist 7, elbow 6, shoulder 5
|
158 |
+
ratioWristElbow = 0.33
|
159 |
+
detect_result = []
|
160 |
+
image_height, image_width = oriImg.shape[0:2]
|
161 |
+
for person in subset.astype(int):
|
162 |
+
# if any of three not detected
|
163 |
+
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
164 |
+
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
165 |
+
if not (has_left or has_right):
|
166 |
+
continue
|
167 |
+
hands = []
|
168 |
+
#left hand
|
169 |
+
if has_left:
|
170 |
+
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
|
171 |
+
x1, y1 = candidate[left_shoulder_index][:2]
|
172 |
+
x2, y2 = candidate[left_elbow_index][:2]
|
173 |
+
x3, y3 = candidate[left_wrist_index][:2]
|
174 |
+
hands.append([x1, y1, x2, y2, x3, y3, True])
|
175 |
+
# right hand
|
176 |
+
if has_right:
|
177 |
+
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
|
178 |
+
x1, y1 = candidate[right_shoulder_index][:2]
|
179 |
+
x2, y2 = candidate[right_elbow_index][:2]
|
180 |
+
x3, y3 = candidate[right_wrist_index][:2]
|
181 |
+
hands.append([x1, y1, x2, y2, x3, y3, False])
|
182 |
+
|
183 |
+
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
184 |
+
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
185 |
+
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
186 |
+
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
187 |
+
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
188 |
+
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
189 |
+
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
190 |
+
x = x3 + ratioWristElbow * (x3 - x2)
|
191 |
+
y = y3 + ratioWristElbow * (y3 - y2)
|
192 |
+
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
193 |
+
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
194 |
+
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
195 |
+
# x-y refers to the center --> offset to topLeft point
|
196 |
+
# handRectangle.x -= handRectangle.width / 2.f;
|
197 |
+
# handRectangle.y -= handRectangle.height / 2.f;
|
198 |
+
x -= width / 2
|
199 |
+
y -= width / 2 # width = height
|
200 |
+
# overflow the image
|
201 |
+
if x < 0: x = 0
|
202 |
+
if y < 0: y = 0
|
203 |
+
width1 = width
|
204 |
+
width2 = width
|
205 |
+
if x + width > image_width: width1 = image_width - x
|
206 |
+
if y + width > image_height: width2 = image_height - y
|
207 |
+
width = min(width1, width2)
|
208 |
+
# the max hand box value is 20 pixels
|
209 |
+
if width >= 20:
|
210 |
+
detect_result.append([int(x), int(y), int(width), is_left])
|
211 |
+
|
212 |
+
'''
|
213 |
+
return value: [[x, y, w, True if left hand else False]].
|
214 |
+
width=height since the network require squared input.
|
215 |
+
x, y is the coordinate of top left
|
216 |
+
'''
|
217 |
+
return detect_result
|
218 |
+
|
219 |
+
|
220 |
+
# Written by Lvmin
|
221 |
+
def faceDetect(candidate, subset, oriImg):
|
222 |
+
# left right eye ear 14 15 16 17
|
223 |
+
detect_result = []
|
224 |
+
image_height, image_width = oriImg.shape[0:2]
|
225 |
+
for person in subset.astype(int):
|
226 |
+
has_head = person[0] > -1
|
227 |
+
if not has_head:
|
228 |
+
continue
|
229 |
+
|
230 |
+
has_left_eye = person[14] > -1
|
231 |
+
has_right_eye = person[15] > -1
|
232 |
+
has_left_ear = person[16] > -1
|
233 |
+
has_right_ear = person[17] > -1
|
234 |
+
|
235 |
+
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
|
236 |
+
continue
|
237 |
+
|
238 |
+
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
|
239 |
+
|
240 |
+
width = 0.0
|
241 |
+
x0, y0 = candidate[head][:2]
|
242 |
+
|
243 |
+
if has_left_eye:
|
244 |
+
x1, y1 = candidate[left_eye][:2]
|
245 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
246 |
+
width = max(width, d * 3.0)
|
247 |
+
|
248 |
+
if has_right_eye:
|
249 |
+
x1, y1 = candidate[right_eye][:2]
|
250 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
251 |
+
width = max(width, d * 3.0)
|
252 |
+
|
253 |
+
if has_left_ear:
|
254 |
+
x1, y1 = candidate[left_ear][:2]
|
255 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
256 |
+
width = max(width, d * 1.5)
|
257 |
+
|
258 |
+
if has_right_ear:
|
259 |
+
x1, y1 = candidate[right_ear][:2]
|
260 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
261 |
+
width = max(width, d * 1.5)
|
262 |
+
|
263 |
+
x, y = x0, y0
|
264 |
+
|
265 |
+
x -= width
|
266 |
+
y -= width
|
267 |
+
|
268 |
+
if x < 0:
|
269 |
+
x = 0
|
270 |
+
|
271 |
+
if y < 0:
|
272 |
+
y = 0
|
273 |
+
|
274 |
+
width1 = width * 2
|
275 |
+
width2 = width * 2
|
276 |
+
|
277 |
+
if x + width > image_width:
|
278 |
+
width1 = image_width - x
|
279 |
+
|
280 |
+
if y + width > image_height:
|
281 |
+
width2 = image_height - y
|
282 |
+
|
283 |
+
width = min(width1, width2)
|
284 |
+
|
285 |
+
if width >= 20:
|
286 |
+
detect_result.append([int(x), int(y), int(width)])
|
287 |
+
|
288 |
+
return detect_result
|
289 |
+
|
290 |
+
|
291 |
+
# get max index of 2d array
|
292 |
+
def npmax(array):
|
293 |
+
arrayindex = array.argmax(1)
|
294 |
+
arrayvalue = array.max(1)
|
295 |
+
i = arrayvalue.argmax()
|
296 |
+
j = arrayindex[i]
|
297 |
+
return i, j
|
src/flux/annotator/dwpose/wholebody.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import onnxruntime as ort
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from .onnxdet import inference_detector
|
7 |
+
from .onnxpose import inference_pose
|
8 |
+
|
9 |
+
|
10 |
+
class Wholebody:
|
11 |
+
def __init__(self, device="cuda:0"):
|
12 |
+
providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
|
13 |
+
onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx")
|
14 |
+
onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx")
|
15 |
+
|
16 |
+
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
17 |
+
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
18 |
+
|
19 |
+
def __call__(self, oriImg):
|
20 |
+
det_result = inference_detector(self.session_det, oriImg)
|
21 |
+
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
22 |
+
|
23 |
+
keypoints_info = np.concatenate(
|
24 |
+
(keypoints, scores[..., None]), axis=-1)
|
25 |
+
# compute neck joint
|
26 |
+
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
27 |
+
# neck score when visualizing pred
|
28 |
+
neck[:, 2:4] = np.logical_and(
|
29 |
+
keypoints_info[:, 5, 2:4] > 0.3,
|
30 |
+
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
31 |
+
new_keypoints_info = np.insert(
|
32 |
+
keypoints_info, 17, neck, axis=1)
|
33 |
+
mmpose_idx = [
|
34 |
+
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
35 |
+
]
|
36 |
+
openpose_idx = [
|
37 |
+
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
38 |
+
]
|
39 |
+
new_keypoints_info[:, openpose_idx] = \
|
40 |
+
new_keypoints_info[:, mmpose_idx]
|
41 |
+
keypoints_info = new_keypoints_info
|
42 |
+
|
43 |
+
keypoints, scores = keypoints_info[
|
44 |
+
..., :2], keypoints_info[..., 2]
|
45 |
+
|
46 |
+
return keypoints, scores
|
47 |
+
|
48 |
+
|
src/flux/annotator/hed/__init__.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
|
2 |
+
# Please use this implementation in your products
|
3 |
+
# This implementation may produce slightly different results from Saining Xie's official implementations,
|
4 |
+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
|
5 |
+
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
|
6 |
+
# and in this way it works better for gradio's RGB protocol
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
from einops import rearrange
|
15 |
+
from ...annotator.util import annotator_ckpts_path
|
16 |
+
|
17 |
+
|
18 |
+
class DoubleConvBlock(torch.nn.Module):
|
19 |
+
def __init__(self, input_channel, output_channel, layer_number):
|
20 |
+
super().__init__()
|
21 |
+
self.convs = torch.nn.Sequential()
|
22 |
+
self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
23 |
+
for i in range(1, layer_number):
|
24 |
+
self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
25 |
+
self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
|
26 |
+
|
27 |
+
def __call__(self, x, down_sampling=False):
|
28 |
+
h = x
|
29 |
+
if down_sampling:
|
30 |
+
h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
|
31 |
+
for conv in self.convs:
|
32 |
+
h = conv(h)
|
33 |
+
h = torch.nn.functional.relu(h)
|
34 |
+
return h, self.projection(h)
|
35 |
+
|
36 |
+
|
37 |
+
class ControlNetHED_Apache2(torch.nn.Module):
|
38 |
+
def __init__(self):
|
39 |
+
super().__init__()
|
40 |
+
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
|
41 |
+
self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
|
42 |
+
self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
|
43 |
+
self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
|
44 |
+
self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
|
45 |
+
self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
|
46 |
+
|
47 |
+
def __call__(self, x):
|
48 |
+
h = x - self.norm
|
49 |
+
h, projection1 = self.block1(h)
|
50 |
+
h, projection2 = self.block2(h, down_sampling=True)
|
51 |
+
h, projection3 = self.block3(h, down_sampling=True)
|
52 |
+
h, projection4 = self.block4(h, down_sampling=True)
|
53 |
+
h, projection5 = self.block5(h, down_sampling=True)
|
54 |
+
return projection1, projection2, projection3, projection4, projection5
|
55 |
+
|
56 |
+
|
57 |
+
class HEDdetector:
|
58 |
+
def __init__(self):
|
59 |
+
modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
|
60 |
+
if not os.path.exists(modelpath):
|
61 |
+
modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
|
62 |
+
self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
|
63 |
+
self.netNetwork.load_state_dict(torch.load(modelpath))
|
64 |
+
|
65 |
+
def __call__(self, input_image):
|
66 |
+
assert input_image.ndim == 3
|
67 |
+
H, W, C = input_image.shape
|
68 |
+
with torch.no_grad():
|
69 |
+
image_hed = torch.from_numpy(input_image.copy()).float().cuda()
|
70 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
71 |
+
edges = self.netNetwork(image_hed)
|
72 |
+
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
|
73 |
+
edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
|
74 |
+
edges = np.stack(edges, axis=2)
|
75 |
+
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
76 |
+
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
77 |
+
return edge
|
78 |
+
|
79 |
+
|
80 |
+
def nms(x, t, s):
|
81 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
82 |
+
|
83 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
84 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
85 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
86 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
87 |
+
|
88 |
+
y = np.zeros_like(x)
|
89 |
+
|
90 |
+
for f in [f1, f2, f3, f4]:
|
91 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
92 |
+
|
93 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
94 |
+
z[y > t] = 255
|
95 |
+
return z
|
src/flux/annotator/midas/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
src/flux/annotator/midas/__init__.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Midas Depth Estimation
|
2 |
+
# From https://github.com/isl-org/MiDaS
|
3 |
+
# MIT LICENSE
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
from .api import MiDaSInference
|
11 |
+
|
12 |
+
|
13 |
+
class MidasDetector:
|
14 |
+
def __init__(self):
|
15 |
+
self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
|
16 |
+
|
17 |
+
def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
|
18 |
+
assert input_image.ndim == 3
|
19 |
+
image_depth = input_image
|
20 |
+
with torch.no_grad():
|
21 |
+
image_depth = torch.from_numpy(image_depth).float().cuda()
|
22 |
+
image_depth = image_depth / 127.5 - 1.0
|
23 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
24 |
+
depth = self.model(image_depth)[0]
|
25 |
+
|
26 |
+
depth_pt = depth.clone()
|
27 |
+
depth_pt -= torch.min(depth_pt)
|
28 |
+
depth_pt /= torch.max(depth_pt)
|
29 |
+
depth_pt = depth_pt.cpu().numpy()
|
30 |
+
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
|
31 |
+
|
32 |
+
depth_np = depth.cpu().numpy()
|
33 |
+
x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
|
34 |
+
y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
|
35 |
+
z = np.ones_like(x) * a
|
36 |
+
x[depth_pt < bg_th] = 0
|
37 |
+
y[depth_pt < bg_th] = 0
|
38 |
+
normal = np.stack([x, y, z], axis=2)
|
39 |
+
normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
|
40 |
+
normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
41 |
+
|
42 |
+
return depth_image, normal_image
|
src/flux/annotator/midas/api.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/isl-org/MiDaS
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.transforms import Compose
|
8 |
+
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
from .midas.dpt_depth import DPTDepthModel
|
12 |
+
from .midas.midas_net import MidasNet
|
13 |
+
from .midas.midas_net_custom import MidasNet_small
|
14 |
+
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
|
15 |
+
from ...annotator.util import annotator_ckpts_path
|
16 |
+
|
17 |
+
|
18 |
+
ISL_PATHS = {
|
19 |
+
"dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
|
20 |
+
"dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
|
21 |
+
"midas_v21": "",
|
22 |
+
"midas_v21_small": "",
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def disabled_train(self, mode=True):
|
27 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
28 |
+
does not change anymore."""
|
29 |
+
return self
|
30 |
+
|
31 |
+
|
32 |
+
def load_midas_transform(model_type):
|
33 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
34 |
+
# load transform only
|
35 |
+
if model_type == "dpt_large": # DPT-Large
|
36 |
+
net_w, net_h = 384, 384
|
37 |
+
resize_mode = "minimal"
|
38 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
39 |
+
|
40 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
41 |
+
net_w, net_h = 384, 384
|
42 |
+
resize_mode = "minimal"
|
43 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
44 |
+
|
45 |
+
elif model_type == "midas_v21":
|
46 |
+
net_w, net_h = 384, 384
|
47 |
+
resize_mode = "upper_bound"
|
48 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
49 |
+
|
50 |
+
elif model_type == "midas_v21_small":
|
51 |
+
net_w, net_h = 256, 256
|
52 |
+
resize_mode = "upper_bound"
|
53 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
54 |
+
|
55 |
+
else:
|
56 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
57 |
+
|
58 |
+
transform = Compose(
|
59 |
+
[
|
60 |
+
Resize(
|
61 |
+
net_w,
|
62 |
+
net_h,
|
63 |
+
resize_target=None,
|
64 |
+
keep_aspect_ratio=True,
|
65 |
+
ensure_multiple_of=32,
|
66 |
+
resize_method=resize_mode,
|
67 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
68 |
+
),
|
69 |
+
normalization,
|
70 |
+
PrepareForNet(),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
|
74 |
+
return transform
|
75 |
+
|
76 |
+
|
77 |
+
def load_model(model_type):
|
78 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
79 |
+
# load network
|
80 |
+
model_path = ISL_PATHS[model_type]
|
81 |
+
if model_type == "dpt_large": # DPT-Large
|
82 |
+
model = DPTDepthModel(
|
83 |
+
path=model_path,
|
84 |
+
backbone="vitl16_384",
|
85 |
+
non_negative=True,
|
86 |
+
)
|
87 |
+
net_w, net_h = 384, 384
|
88 |
+
resize_mode = "minimal"
|
89 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
90 |
+
|
91 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
92 |
+
if not os.path.exists(model_path):
|
93 |
+
model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt")
|
94 |
+
|
95 |
+
model = DPTDepthModel(
|
96 |
+
path=model_path,
|
97 |
+
backbone="vitb_rn50_384",
|
98 |
+
non_negative=True,
|
99 |
+
)
|
100 |
+
net_w, net_h = 384, 384
|
101 |
+
resize_mode = "minimal"
|
102 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
103 |
+
|
104 |
+
elif model_type == "midas_v21":
|
105 |
+
model = MidasNet(model_path, non_negative=True)
|
106 |
+
net_w, net_h = 384, 384
|
107 |
+
resize_mode = "upper_bound"
|
108 |
+
normalization = NormalizeImage(
|
109 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
110 |
+
)
|
111 |
+
|
112 |
+
elif model_type == "midas_v21_small":
|
113 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
114 |
+
non_negative=True, blocks={'expand': True})
|
115 |
+
net_w, net_h = 256, 256
|
116 |
+
resize_mode = "upper_bound"
|
117 |
+
normalization = NormalizeImage(
|
118 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
119 |
+
)
|
120 |
+
|
121 |
+
else:
|
122 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
123 |
+
assert False
|
124 |
+
|
125 |
+
transform = Compose(
|
126 |
+
[
|
127 |
+
Resize(
|
128 |
+
net_w,
|
129 |
+
net_h,
|
130 |
+
resize_target=None,
|
131 |
+
keep_aspect_ratio=True,
|
132 |
+
ensure_multiple_of=32,
|
133 |
+
resize_method=resize_mode,
|
134 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
135 |
+
),
|
136 |
+
normalization,
|
137 |
+
PrepareForNet(),
|
138 |
+
]
|
139 |
+
)
|
140 |
+
|
141 |
+
return model.eval(), transform
|
142 |
+
|
143 |
+
|
144 |
+
class MiDaSInference(nn.Module):
|
145 |
+
MODEL_TYPES_TORCH_HUB = [
|
146 |
+
"DPT_Large",
|
147 |
+
"DPT_Hybrid",
|
148 |
+
"MiDaS_small"
|
149 |
+
]
|
150 |
+
MODEL_TYPES_ISL = [
|
151 |
+
"dpt_large",
|
152 |
+
"dpt_hybrid",
|
153 |
+
"midas_v21",
|
154 |
+
"midas_v21_small",
|
155 |
+
]
|
156 |
+
|
157 |
+
def __init__(self, model_type):
|
158 |
+
super().__init__()
|
159 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
160 |
+
model, _ = load_model(model_type)
|
161 |
+
self.model = model
|
162 |
+
self.model.train = disabled_train
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
with torch.no_grad():
|
166 |
+
prediction = self.model(x)
|
167 |
+
return prediction
|
168 |
+
|
src/flux/annotator/midas/midas/__init__.py
ADDED
File without changes
|
src/flux/annotator/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
src/flux/annotator/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
src/flux/annotator/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
src/flux/annotator/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
src/flux/annotator/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
src/flux/annotator/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
width,
|
55 |
+
height,
|
56 |
+
resize_target=True,
|
57 |
+
keep_aspect_ratio=False,
|
58 |
+
ensure_multiple_of=1,
|
59 |
+
resize_method="lower_bound",
|
60 |
+
image_interpolation_method=cv2.INTER_AREA,
|
61 |
+
):
|
62 |
+
"""Init.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
width (int): desired output width
|
66 |
+
height (int): desired output height
|
67 |
+
resize_target (bool, optional):
|
68 |
+
True: Resize the full sample (image, mask, target).
|
69 |
+
False: Resize image only.
|
70 |
+
Defaults to True.
|
71 |
+
keep_aspect_ratio (bool, optional):
|
72 |
+
True: Keep the aspect ratio of the input sample.
|
73 |
+
Output sample might not have the given width and height, and
|
74 |
+
resize behaviour depends on the parameter 'resize_method'.
|
75 |
+
Defaults to False.
|
76 |
+
ensure_multiple_of (int, optional):
|
77 |
+
Output width and height is constrained to be multiple of this parameter.
|
78 |
+
Defaults to 1.
|
79 |
+
resize_method (str, optional):
|
80 |
+
"lower_bound": Output will be at least as large as the given size.
|
81 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
82 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
83 |
+
Defaults to "lower_bound".
|
84 |
+
"""
|
85 |
+
self.__width = width
|
86 |
+
self.__height = height
|
87 |
+
|
88 |
+
self.__resize_target = resize_target
|
89 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
90 |
+
self.__multiple_of = ensure_multiple_of
|
91 |
+
self.__resize_method = resize_method
|
92 |
+
self.__image_interpolation_method = image_interpolation_method
|
93 |
+
|
94 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
95 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
96 |
+
|
97 |
+
if max_val is not None and y > max_val:
|
98 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if y < min_val:
|
101 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
return y
|
104 |
+
|
105 |
+
def get_size(self, width, height):
|
106 |
+
# determine new height and width
|
107 |
+
scale_height = self.__height / height
|
108 |
+
scale_width = self.__width / width
|
109 |
+
|
110 |
+
if self.__keep_aspect_ratio:
|
111 |
+
if self.__resize_method == "lower_bound":
|
112 |
+
# scale such that output size is lower bound
|
113 |
+
if scale_width > scale_height:
|
114 |
+
# fit width
|
115 |
+
scale_height = scale_width
|
116 |
+
else:
|
117 |
+
# fit height
|
118 |
+
scale_width = scale_height
|
119 |
+
elif self.__resize_method == "upper_bound":
|
120 |
+
# scale such that output size is upper bound
|
121 |
+
if scale_width < scale_height:
|
122 |
+
# fit width
|
123 |
+
scale_height = scale_width
|
124 |
+
else:
|
125 |
+
# fit height
|
126 |
+
scale_width = scale_height
|
127 |
+
elif self.__resize_method == "minimal":
|
128 |
+
# scale as least as possbile
|
129 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
130 |
+
# fit width
|
131 |
+
scale_height = scale_width
|
132 |
+
else:
|
133 |
+
# fit height
|
134 |
+
scale_width = scale_height
|
135 |
+
else:
|
136 |
+
raise ValueError(
|
137 |
+
f"resize_method {self.__resize_method} not implemented"
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.__resize_method == "lower_bound":
|
141 |
+
new_height = self.constrain_to_multiple_of(
|
142 |
+
scale_height * height, min_val=self.__height
|
143 |
+
)
|
144 |
+
new_width = self.constrain_to_multiple_of(
|
145 |
+
scale_width * width, min_val=self.__width
|
146 |
+
)
|
147 |
+
elif self.__resize_method == "upper_bound":
|
148 |
+
new_height = self.constrain_to_multiple_of(
|
149 |
+
scale_height * height, max_val=self.__height
|
150 |
+
)
|
151 |
+
new_width = self.constrain_to_multiple_of(
|
152 |
+
scale_width * width, max_val=self.__width
|
153 |
+
)
|
154 |
+
elif self.__resize_method == "minimal":
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(
|
164 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
165 |
+
)
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample["image"] = cv2.resize(
|
169 |
+
sample["image"],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if "disparity" in sample:
|
176 |
+
sample["disparity"] = cv2.resize(
|
177 |
+
sample["disparity"],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if "depth" in sample:
|
183 |
+
sample["depth"] = cv2.resize(
|
184 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
185 |
+
)
|
186 |
+
|
187 |
+
sample["mask"] = cv2.resize(
|
188 |
+
sample["mask"].astype(np.float32),
|
189 |
+
(width, height),
|
190 |
+
interpolation=cv2.INTER_NEAREST,
|
191 |
+
)
|
192 |
+
sample["mask"] = sample["mask"].astype(bool)
|
193 |
+
|
194 |
+
return sample
|
195 |
+
|
196 |
+
|
197 |
+
class NormalizeImage(object):
|
198 |
+
"""Normlize image by given mean and std.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, mean, std):
|
202 |
+
self.__mean = mean
|
203 |
+
self.__std = std
|
204 |
+
|
205 |
+
def __call__(self, sample):
|
206 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
207 |
+
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class PrepareForNet(object):
|
212 |
+
"""Prepare sample for usage as network input.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
pass
|
217 |
+
|
218 |
+
def __call__(self, sample):
|
219 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
220 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
221 |
+
|
222 |
+
if "mask" in sample:
|
223 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
224 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
225 |
+
|
226 |
+
if "disparity" in sample:
|
227 |
+
disparity = sample["disparity"].astype(np.float32)
|
228 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
229 |
+
|
230 |
+
if "depth" in sample:
|
231 |
+
depth = sample["depth"].astype(np.float32)
|
232 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
233 |
+
|
234 |
+
return sample
|
src/flux/annotator/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class Slice(nn.Module):
|
10 |
+
def __init__(self, start_index=1):
|
11 |
+
super(Slice, self).__init__()
|
12 |
+
self.start_index = start_index
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return x[:, self.start_index :]
|
16 |
+
|
17 |
+
|
18 |
+
class AddReadout(nn.Module):
|
19 |
+
def __init__(self, start_index=1):
|
20 |
+
super(AddReadout, self).__init__()
|
21 |
+
self.start_index = start_index
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
if self.start_index == 2:
|
25 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
26 |
+
else:
|
27 |
+
readout = x[:, 0]
|
28 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
29 |
+
|
30 |
+
|
31 |
+
class ProjectReadout(nn.Module):
|
32 |
+
def __init__(self, in_features, start_index=1):
|
33 |
+
super(ProjectReadout, self).__init__()
|
34 |
+
self.start_index = start_index
|
35 |
+
|
36 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
40 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
41 |
+
|
42 |
+
return self.project(features)
|
43 |
+
|
44 |
+
|
45 |
+
class Transpose(nn.Module):
|
46 |
+
def __init__(self, dim0, dim1):
|
47 |
+
super(Transpose, self).__init__()
|
48 |
+
self.dim0 = dim0
|
49 |
+
self.dim1 = dim1
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x.transpose(self.dim0, self.dim1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
def forward_vit(pretrained, x):
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
|
59 |
+
glob = pretrained.model.forward_flex(x)
|
60 |
+
|
61 |
+
layer_1 = pretrained.activations["1"]
|
62 |
+
layer_2 = pretrained.activations["2"]
|
63 |
+
layer_3 = pretrained.activations["3"]
|
64 |
+
layer_4 = pretrained.activations["4"]
|
65 |
+
|
66 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
67 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
68 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
69 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
70 |
+
|
71 |
+
unflatten = nn.Sequential(
|
72 |
+
nn.Unflatten(
|
73 |
+
2,
|
74 |
+
torch.Size(
|
75 |
+
[
|
76 |
+
h // pretrained.model.patch_size[1],
|
77 |
+
w // pretrained.model.patch_size[0],
|
78 |
+
]
|
79 |
+
),
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
93 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
94 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
95 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
96 |
+
|
97 |
+
return layer_1, layer_2, layer_3, layer_4
|
98 |
+
|
99 |
+
|
100 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
101 |
+
posemb_tok, posemb_grid = (
|
102 |
+
posemb[:, : self.start_index],
|
103 |
+
posemb[0, self.start_index :],
|
104 |
+
)
|
105 |
+
|
106 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
107 |
+
|
108 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
109 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
110 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
111 |
+
|
112 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
113 |
+
|
114 |
+
return posemb
|
115 |
+
|
116 |
+
|
117 |
+
def forward_flex(self, x):
|
118 |
+
b, c, h, w = x.shape
|
119 |
+
|
120 |
+
pos_embed = self._resize_pos_embed(
|
121 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
122 |
+
)
|
123 |
+
|
124 |
+
B = x.shape[0]
|
125 |
+
|
126 |
+
if hasattr(self.patch_embed, "backbone"):
|
127 |
+
x = self.patch_embed.backbone(x)
|
128 |
+
if isinstance(x, (list, tuple)):
|
129 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
130 |
+
|
131 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
|
133 |
+
if getattr(self, "dist_token", None) is not None:
|
134 |
+
cls_tokens = self.cls_token.expand(
|
135 |
+
B, -1, -1
|
136 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
137 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
138 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
139 |
+
else:
|
140 |
+
cls_tokens = self.cls_token.expand(
|
141 |
+
B, -1, -1
|
142 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
144 |
+
|
145 |
+
x = x + pos_embed
|
146 |
+
x = self.pos_drop(x)
|
147 |
+
|
148 |
+
for blk in self.blocks:
|
149 |
+
x = blk(x)
|
150 |
+
|
151 |
+
x = self.norm(x)
|
152 |
+
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
activations = {}
|
157 |
+
|
158 |
+
|
159 |
+
def get_activation(name):
|
160 |
+
def hook(model, input, output):
|
161 |
+
activations[name] = output
|
162 |
+
|
163 |
+
return hook
|
164 |
+
|
165 |
+
|
166 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
167 |
+
if use_readout == "ignore":
|
168 |
+
readout_oper = [Slice(start_index)] * len(features)
|
169 |
+
elif use_readout == "add":
|
170 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
171 |
+
elif use_readout == "project":
|
172 |
+
readout_oper = [
|
173 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
174 |
+
]
|
175 |
+
else:
|
176 |
+
assert (
|
177 |
+
False
|
178 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
179 |
+
|
180 |
+
return readout_oper
|
181 |
+
|
182 |
+
|
183 |
+
def _make_vit_b16_backbone(
|
184 |
+
model,
|
185 |
+
features=[96, 192, 384, 768],
|
186 |
+
size=[384, 384],
|
187 |
+
hooks=[2, 5, 8, 11],
|
188 |
+
vit_features=768,
|
189 |
+
use_readout="ignore",
|
190 |
+
start_index=1,
|
191 |
+
):
|
192 |
+
pretrained = nn.Module()
|
193 |
+
|
194 |
+
pretrained.model = model
|
195 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
196 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
197 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
198 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
199 |
+
|
200 |
+
pretrained.activations = activations
|
201 |
+
|
202 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
203 |
+
|
204 |
+
# 32, 48, 136, 384
|
205 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
206 |
+
readout_oper[0],
|
207 |
+
Transpose(1, 2),
|
208 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
209 |
+
nn.Conv2d(
|
210 |
+
in_channels=vit_features,
|
211 |
+
out_channels=features[0],
|
212 |
+
kernel_size=1,
|
213 |
+
stride=1,
|
214 |
+
padding=0,
|
215 |
+
),
|
216 |
+
nn.ConvTranspose2d(
|
217 |
+
in_channels=features[0],
|
218 |
+
out_channels=features[0],
|
219 |
+
kernel_size=4,
|
220 |
+
stride=4,
|
221 |
+
padding=0,
|
222 |
+
bias=True,
|
223 |
+
dilation=1,
|
224 |
+
groups=1,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
229 |
+
readout_oper[1],
|
230 |
+
Transpose(1, 2),
|
231 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
232 |
+
nn.Conv2d(
|
233 |
+
in_channels=vit_features,
|
234 |
+
out_channels=features[1],
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0,
|
238 |
+
),
|
239 |
+
nn.ConvTranspose2d(
|
240 |
+
in_channels=features[1],
|
241 |
+
out_channels=features[1],
|
242 |
+
kernel_size=2,
|
243 |
+
stride=2,
|
244 |
+
padding=0,
|
245 |
+
bias=True,
|
246 |
+
dilation=1,
|
247 |
+
groups=1,
|
248 |
+
),
|
249 |
+
)
|
250 |
+
|
251 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
252 |
+
readout_oper[2],
|
253 |
+
Transpose(1, 2),
|
254 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
255 |
+
nn.Conv2d(
|
256 |
+
in_channels=vit_features,
|
257 |
+
out_channels=features[2],
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
265 |
+
readout_oper[3],
|
266 |
+
Transpose(1, 2),
|
267 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
268 |
+
nn.Conv2d(
|
269 |
+
in_channels=vit_features,
|
270 |
+
out_channels=features[3],
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
),
|
275 |
+
nn.Conv2d(
|
276 |
+
in_channels=features[3],
|
277 |
+
out_channels=features[3],
|
278 |
+
kernel_size=3,
|
279 |
+
stride=2,
|
280 |
+
padding=1,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
|
284 |
+
pretrained.model.start_index = start_index
|
285 |
+
pretrained.model.patch_size = [16, 16]
|
286 |
+
|
287 |
+
# We inject this function into the VisionTransformer instances so that
|
288 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
289 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
290 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
291 |
+
_resize_pos_embed, pretrained.model
|
292 |
+
)
|
293 |
+
|
294 |
+
return pretrained
|
295 |
+
|
296 |
+
|
297 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
298 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
299 |
+
|
300 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
301 |
+
return _make_vit_b16_backbone(
|
302 |
+
model,
|
303 |
+
features=[256, 512, 1024, 1024],
|
304 |
+
hooks=hooks,
|
305 |
+
vit_features=1024,
|
306 |
+
use_readout=use_readout,
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
311 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
312 |
+
|
313 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
314 |
+
return _make_vit_b16_backbone(
|
315 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
320 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
321 |
+
|
322 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
323 |
+
return _make_vit_b16_backbone(
|
324 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
329 |
+
model = timm.create_model(
|
330 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
331 |
+
)
|
332 |
+
|
333 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
334 |
+
return _make_vit_b16_backbone(
|
335 |
+
model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout,
|
339 |
+
start_index=2,
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
def _make_vit_b_rn50_backbone(
|
344 |
+
model,
|
345 |
+
features=[256, 512, 768, 768],
|
346 |
+
size=[384, 384],
|
347 |
+
hooks=[0, 1, 8, 11],
|
348 |
+
vit_features=768,
|
349 |
+
use_vit_only=False,
|
350 |
+
use_readout="ignore",
|
351 |
+
start_index=1,
|
352 |
+
):
|
353 |
+
pretrained = nn.Module()
|
354 |
+
|
355 |
+
pretrained.model = model
|
356 |
+
|
357 |
+
if use_vit_only == True:
|
358 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
359 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
360 |
+
else:
|
361 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
362 |
+
get_activation("1")
|
363 |
+
)
|
364 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
365 |
+
get_activation("2")
|
366 |
+
)
|
367 |
+
|
368 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
369 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
370 |
+
|
371 |
+
pretrained.activations = activations
|
372 |
+
|
373 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
374 |
+
|
375 |
+
if use_vit_only == True:
|
376 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
377 |
+
readout_oper[0],
|
378 |
+
Transpose(1, 2),
|
379 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
380 |
+
nn.Conv2d(
|
381 |
+
in_channels=vit_features,
|
382 |
+
out_channels=features[0],
|
383 |
+
kernel_size=1,
|
384 |
+
stride=1,
|
385 |
+
padding=0,
|
386 |
+
),
|
387 |
+
nn.ConvTranspose2d(
|
388 |
+
in_channels=features[0],
|
389 |
+
out_channels=features[0],
|
390 |
+
kernel_size=4,
|
391 |
+
stride=4,
|
392 |
+
padding=0,
|
393 |
+
bias=True,
|
394 |
+
dilation=1,
|
395 |
+
groups=1,
|
396 |
+
),
|
397 |
+
)
|
398 |
+
|
399 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
400 |
+
readout_oper[1],
|
401 |
+
Transpose(1, 2),
|
402 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
403 |
+
nn.Conv2d(
|
404 |
+
in_channels=vit_features,
|
405 |
+
out_channels=features[1],
|
406 |
+
kernel_size=1,
|
407 |
+
stride=1,
|
408 |
+
padding=0,
|
409 |
+
),
|
410 |
+
nn.ConvTranspose2d(
|
411 |
+
in_channels=features[1],
|
412 |
+
out_channels=features[1],
|
413 |
+
kernel_size=2,
|
414 |
+
stride=2,
|
415 |
+
padding=0,
|
416 |
+
bias=True,
|
417 |
+
dilation=1,
|
418 |
+
groups=1,
|
419 |
+
),
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
423 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
424 |
+
)
|
425 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
426 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
427 |
+
)
|
428 |
+
|
429 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
430 |
+
readout_oper[2],
|
431 |
+
Transpose(1, 2),
|
432 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
433 |
+
nn.Conv2d(
|
434 |
+
in_channels=vit_features,
|
435 |
+
out_channels=features[2],
|
436 |
+
kernel_size=1,
|
437 |
+
stride=1,
|
438 |
+
padding=0,
|
439 |
+
),
|
440 |
+
)
|
441 |
+
|
442 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
443 |
+
readout_oper[3],
|
444 |
+
Transpose(1, 2),
|
445 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
446 |
+
nn.Conv2d(
|
447 |
+
in_channels=vit_features,
|
448 |
+
out_channels=features[3],
|
449 |
+
kernel_size=1,
|
450 |
+
stride=1,
|
451 |
+
padding=0,
|
452 |
+
),
|
453 |
+
nn.Conv2d(
|
454 |
+
in_channels=features[3],
|
455 |
+
out_channels=features[3],
|
456 |
+
kernel_size=3,
|
457 |
+
stride=2,
|
458 |
+
padding=1,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
pretrained.model.start_index = start_index
|
463 |
+
pretrained.model.patch_size = [16, 16]
|
464 |
+
|
465 |
+
# We inject this function into the VisionTransformer instances so that
|
466 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
467 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
468 |
+
|
469 |
+
# We inject this function into the VisionTransformer instances so that
|
470 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
471 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
472 |
+
_resize_pos_embed, pretrained.model
|
473 |
+
)
|
474 |
+
|
475 |
+
return pretrained
|
476 |
+
|
477 |
+
|
478 |
+
def _make_pretrained_vitb_rn50_384(
|
479 |
+
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
480 |
+
):
|
481 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
482 |
+
|
483 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
484 |
+
return _make_vit_b_rn50_backbone(
|
485 |
+
model,
|
486 |
+
features=[256, 512, 768, 768],
|
487 |
+
size=[384, 384],
|
488 |
+
hooks=hooks,
|
489 |
+
use_vit_only=use_vit_only,
|
490 |
+
use_readout=use_readout,
|
491 |
+
)
|
src/flux/annotator/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils for monoDepth."""
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def read_pfm(path):
|
10 |
+
"""Read pfm file.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
path (str): path to file
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple: (data, scale)
|
17 |
+
"""
|
18 |
+
with open(path, "rb") as file:
|
19 |
+
|
20 |
+
color = None
|
21 |
+
width = None
|
22 |
+
height = None
|
23 |
+
scale = None
|
24 |
+
endian = None
|
25 |
+
|
26 |
+
header = file.readline().rstrip()
|
27 |
+
if header.decode("ascii") == "PF":
|
28 |
+
color = True
|
29 |
+
elif header.decode("ascii") == "Pf":
|
30 |
+
color = False
|
31 |
+
else:
|
32 |
+
raise Exception("Not a PFM file: " + path)
|
33 |
+
|
34 |
+
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
35 |
+
if dim_match:
|
36 |
+
width, height = list(map(int, dim_match.groups()))
|
37 |
+
else:
|
38 |
+
raise Exception("Malformed PFM header.")
|
39 |
+
|
40 |
+
scale = float(file.readline().decode("ascii").rstrip())
|
41 |
+
if scale < 0:
|
42 |
+
# little-endian
|
43 |
+
endian = "<"
|
44 |
+
scale = -scale
|
45 |
+
else:
|
46 |
+
# big-endian
|
47 |
+
endian = ">"
|
48 |
+
|
49 |
+
data = np.fromfile(file, endian + "f")
|
50 |
+
shape = (height, width, 3) if color else (height, width)
|
51 |
+
|
52 |
+
data = np.reshape(data, shape)
|
53 |
+
data = np.flipud(data)
|
54 |
+
|
55 |
+
return data, scale
|
56 |
+
|
57 |
+
|
58 |
+
def write_pfm(path, image, scale=1):
|
59 |
+
"""Write pfm file.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
path (str): pathto file
|
63 |
+
image (array): data
|
64 |
+
scale (int, optional): Scale. Defaults to 1.
|
65 |
+
"""
|
66 |
+
|
67 |
+
with open(path, "wb") as file:
|
68 |
+
color = None
|
69 |
+
|
70 |
+
if image.dtype.name != "float32":
|
71 |
+
raise Exception("Image dtype must be float32.")
|
72 |
+
|
73 |
+
image = np.flipud(image)
|
74 |
+
|
75 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
76 |
+
color = True
|
77 |
+
elif (
|
78 |
+
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
79 |
+
): # greyscale
|
80 |
+
color = False
|
81 |
+
else:
|
82 |
+
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
83 |
+
|
84 |
+
file.write("PF\n" if color else "Pf\n".encode())
|
85 |
+
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
86 |
+
|
87 |
+
endian = image.dtype.byteorder
|
88 |
+
|
89 |
+
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
90 |
+
scale = -scale
|
91 |
+
|
92 |
+
file.write("%f\n".encode() % scale)
|
93 |
+
|
94 |
+
image.tofile(file)
|
95 |
+
|
96 |
+
|
97 |
+
def read_image(path):
|
98 |
+
"""Read image and output RGB image (0-1).
|
99 |
+
|
100 |
+
Args:
|
101 |
+
path (str): path to file
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
array: RGB image (0-1)
|
105 |
+
"""
|
106 |
+
img = cv2.imread(path)
|
107 |
+
|
108 |
+
if img.ndim == 2:
|
109 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
110 |
+
|
111 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
112 |
+
|
113 |
+
return img
|
114 |
+
|
115 |
+
|
116 |
+
def resize_image(img):
|
117 |
+
"""Resize image and make it fit for network.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
img (array): image
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
tensor: data ready for network
|
124 |
+
"""
|
125 |
+
height_orig = img.shape[0]
|
126 |
+
width_orig = img.shape[1]
|
127 |
+
|
128 |
+
if width_orig > height_orig:
|
129 |
+
scale = width_orig / 384
|
130 |
+
else:
|
131 |
+
scale = height_orig / 384
|
132 |
+
|
133 |
+
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
134 |
+
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
135 |
+
|
136 |
+
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
137 |
+
|
138 |
+
img_resized = (
|
139 |
+
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
140 |
+
)
|
141 |
+
img_resized = img_resized.unsqueeze(0)
|
142 |
+
|
143 |
+
return img_resized
|
144 |
+
|
145 |
+
|
146 |
+
def resize_depth(depth, width, height):
|
147 |
+
"""Resize depth map and bring to CPU (numpy).
|
148 |
+
|
149 |
+
Args:
|
150 |
+
depth (tensor): depth
|
151 |
+
width (int): image width
|
152 |
+
height (int): image height
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
array: processed depth
|
156 |
+
"""
|
157 |
+
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
158 |
+
|
159 |
+
depth_resized = cv2.resize(
|
160 |
+
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
161 |
+
)
|
162 |
+
|
163 |
+
return depth_resized
|
164 |
+
|
165 |
+
def write_depth(path, depth, bits=1):
|
166 |
+
"""Write depth map to pfm and png file.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
path (str): filepath without extension
|
170 |
+
depth (array): depth
|
171 |
+
"""
|
172 |
+
write_pfm(path + ".pfm", depth.astype(np.float32))
|
173 |
+
|
174 |
+
depth_min = depth.min()
|
175 |
+
depth_max = depth.max()
|
176 |
+
|
177 |
+
max_val = (2**(8*bits))-1
|
178 |
+
|
179 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
180 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
181 |
+
else:
|
182 |
+
out = np.zeros(depth.shape, dtype=depth.type)
|
183 |
+
|
184 |
+
if bits == 1:
|
185 |
+
cv2.imwrite(path + ".png", out.astype("uint8"))
|
186 |
+
elif bits == 2:
|
187 |
+
cv2.imwrite(path + ".png", out.astype("uint16"))
|
188 |
+
|
189 |
+
return
|
src/flux/annotator/mlsd/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2021-present NAVER Corp.
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
src/flux/annotator/mlsd/__init__.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MLSD Line Detection
|
2 |
+
# From https://github.com/navervision/mlsd
|
3 |
+
# Apache-2.0 license
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
|
13 |
+
from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
|
14 |
+
from .utils import pred_lines
|
15 |
+
|
16 |
+
from ...annotator.util import annotator_ckpts_path
|
17 |
+
|
18 |
+
|
19 |
+
class MLSDdetector:
|
20 |
+
def __init__(self):
|
21 |
+
model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
|
22 |
+
if not os.path.exists(model_path):
|
23 |
+
model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth")
|
24 |
+
model = MobileV2_MLSD_Large()
|
25 |
+
model.load_state_dict(torch.load(model_path), strict=True)
|
26 |
+
self.model = model.cuda().eval()
|
27 |
+
|
28 |
+
def __call__(self, input_image, thr_v, thr_d):
|
29 |
+
assert input_image.ndim == 3
|
30 |
+
img = input_image
|
31 |
+
img_output = np.zeros_like(img)
|
32 |
+
try:
|
33 |
+
with torch.no_grad():
|
34 |
+
lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
|
35 |
+
for line in lines:
|
36 |
+
x_start, y_start, x_end, y_end = [int(val) for val in line]
|
37 |
+
cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
|
38 |
+
except Exception as e:
|
39 |
+
pass
|
40 |
+
return img_output[:, :, 0]
|
src/flux/annotator/mlsd/models/mbv2_mlsd_large.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
14 |
+
nn.BatchNorm2d(out_c2),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
self.conv2 = nn.Sequential(
|
18 |
+
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(out_c1),
|
20 |
+
nn.ReLU(inplace=True)
|
21 |
+
)
|
22 |
+
self.upscale = upscale
|
23 |
+
|
24 |
+
def forward(self, a, b):
|
25 |
+
b = self.conv1(b)
|
26 |
+
a = self.conv2(a)
|
27 |
+
if self.upscale:
|
28 |
+
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
29 |
+
return torch.cat((a, b), dim=1)
|
30 |
+
|
31 |
+
|
32 |
+
class BlockTypeB(nn.Module):
|
33 |
+
def __init__(self, in_c, out_c):
|
34 |
+
super(BlockTypeB, self).__init__()
|
35 |
+
self.conv1 = nn.Sequential(
|
36 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
37 |
+
nn.BatchNorm2d(in_c),
|
38 |
+
nn.ReLU()
|
39 |
+
)
|
40 |
+
self.conv2 = nn.Sequential(
|
41 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
42 |
+
nn.BatchNorm2d(out_c),
|
43 |
+
nn.ReLU()
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = self.conv1(x) + x
|
48 |
+
x = self.conv2(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
class BlockTypeC(nn.Module):
|
52 |
+
def __init__(self, in_c, out_c):
|
53 |
+
super(BlockTypeC, self).__init__()
|
54 |
+
self.conv1 = nn.Sequential(
|
55 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
56 |
+
nn.BatchNorm2d(in_c),
|
57 |
+
nn.ReLU()
|
58 |
+
)
|
59 |
+
self.conv2 = nn.Sequential(
|
60 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
61 |
+
nn.BatchNorm2d(in_c),
|
62 |
+
nn.ReLU()
|
63 |
+
)
|
64 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
x = self.conv1(x)
|
68 |
+
x = self.conv2(x)
|
69 |
+
x = self.conv3(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
def _make_divisible(v, divisor, min_value=None):
|
73 |
+
"""
|
74 |
+
This function is taken from the original tf repo.
|
75 |
+
It ensures that all layers have a channel number that is divisible by 8
|
76 |
+
It can be seen here:
|
77 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
78 |
+
:param v:
|
79 |
+
:param divisor:
|
80 |
+
:param min_value:
|
81 |
+
:return:
|
82 |
+
"""
|
83 |
+
if min_value is None:
|
84 |
+
min_value = divisor
|
85 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
86 |
+
# Make sure that round down does not go down by more than 10%.
|
87 |
+
if new_v < 0.9 * v:
|
88 |
+
new_v += divisor
|
89 |
+
return new_v
|
90 |
+
|
91 |
+
|
92 |
+
class ConvBNReLU(nn.Sequential):
|
93 |
+
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
94 |
+
self.channel_pad = out_planes - in_planes
|
95 |
+
self.stride = stride
|
96 |
+
#padding = (kernel_size - 1) // 2
|
97 |
+
|
98 |
+
# TFLite uses slightly different padding than PyTorch
|
99 |
+
if stride == 2:
|
100 |
+
padding = 0
|
101 |
+
else:
|
102 |
+
padding = (kernel_size - 1) // 2
|
103 |
+
|
104 |
+
super(ConvBNReLU, self).__init__(
|
105 |
+
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
106 |
+
nn.BatchNorm2d(out_planes),
|
107 |
+
nn.ReLU6(inplace=True)
|
108 |
+
)
|
109 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
110 |
+
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
# TFLite uses different padding
|
114 |
+
if self.stride == 2:
|
115 |
+
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
116 |
+
#print(x.shape)
|
117 |
+
|
118 |
+
for module in self:
|
119 |
+
if not isinstance(module, nn.MaxPool2d):
|
120 |
+
x = module(x)
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class InvertedResidual(nn.Module):
|
125 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
126 |
+
super(InvertedResidual, self).__init__()
|
127 |
+
self.stride = stride
|
128 |
+
assert stride in [1, 2]
|
129 |
+
|
130 |
+
hidden_dim = int(round(inp * expand_ratio))
|
131 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
132 |
+
|
133 |
+
layers = []
|
134 |
+
if expand_ratio != 1:
|
135 |
+
# pw
|
136 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
137 |
+
layers.extend([
|
138 |
+
# dw
|
139 |
+
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
140 |
+
# pw-linear
|
141 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
142 |
+
nn.BatchNorm2d(oup),
|
143 |
+
])
|
144 |
+
self.conv = nn.Sequential(*layers)
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
if self.use_res_connect:
|
148 |
+
return x + self.conv(x)
|
149 |
+
else:
|
150 |
+
return self.conv(x)
|
151 |
+
|
152 |
+
|
153 |
+
class MobileNetV2(nn.Module):
|
154 |
+
def __init__(self, pretrained=True):
|
155 |
+
"""
|
156 |
+
MobileNet V2 main class
|
157 |
+
Args:
|
158 |
+
num_classes (int): Number of classes
|
159 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
160 |
+
inverted_residual_setting: Network structure
|
161 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
162 |
+
Set to 1 to turn off rounding
|
163 |
+
block: Module specifying inverted residual building block for mobilenet
|
164 |
+
"""
|
165 |
+
super(MobileNetV2, self).__init__()
|
166 |
+
|
167 |
+
block = InvertedResidual
|
168 |
+
input_channel = 32
|
169 |
+
last_channel = 1280
|
170 |
+
width_mult = 1.0
|
171 |
+
round_nearest = 8
|
172 |
+
|
173 |
+
inverted_residual_setting = [
|
174 |
+
# t, c, n, s
|
175 |
+
[1, 16, 1, 1],
|
176 |
+
[6, 24, 2, 2],
|
177 |
+
[6, 32, 3, 2],
|
178 |
+
[6, 64, 4, 2],
|
179 |
+
[6, 96, 3, 1],
|
180 |
+
#[6, 160, 3, 2],
|
181 |
+
#[6, 320, 1, 1],
|
182 |
+
]
|
183 |
+
|
184 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
185 |
+
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
186 |
+
raise ValueError("inverted_residual_setting should be non-empty "
|
187 |
+
"or a 4-element list, got {}".format(inverted_residual_setting))
|
188 |
+
|
189 |
+
# building first layer
|
190 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
191 |
+
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
192 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
193 |
+
# building inverted residual blocks
|
194 |
+
for t, c, n, s in inverted_residual_setting:
|
195 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
196 |
+
for i in range(n):
|
197 |
+
stride = s if i == 0 else 1
|
198 |
+
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
199 |
+
input_channel = output_channel
|
200 |
+
|
201 |
+
self.features = nn.Sequential(*features)
|
202 |
+
self.fpn_selected = [1, 3, 6, 10, 13]
|
203 |
+
# weight initialization
|
204 |
+
for m in self.modules():
|
205 |
+
if isinstance(m, nn.Conv2d):
|
206 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
207 |
+
if m.bias is not None:
|
208 |
+
nn.init.zeros_(m.bias)
|
209 |
+
elif isinstance(m, nn.BatchNorm2d):
|
210 |
+
nn.init.ones_(m.weight)
|
211 |
+
nn.init.zeros_(m.bias)
|
212 |
+
elif isinstance(m, nn.Linear):
|
213 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
214 |
+
nn.init.zeros_(m.bias)
|
215 |
+
if pretrained:
|
216 |
+
self._load_pretrained_model()
|
217 |
+
|
218 |
+
def _forward_impl(self, x):
|
219 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
220 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
221 |
+
fpn_features = []
|
222 |
+
for i, f in enumerate(self.features):
|
223 |
+
if i > self.fpn_selected[-1]:
|
224 |
+
break
|
225 |
+
x = f(x)
|
226 |
+
if i in self.fpn_selected:
|
227 |
+
fpn_features.append(x)
|
228 |
+
|
229 |
+
c1, c2, c3, c4, c5 = fpn_features
|
230 |
+
return c1, c2, c3, c4, c5
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return self._forward_impl(x)
|
235 |
+
|
236 |
+
def _load_pretrained_model(self):
|
237 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
238 |
+
model_dict = {}
|
239 |
+
state_dict = self.state_dict()
|
240 |
+
for k, v in pretrain_dict.items():
|
241 |
+
if k in state_dict:
|
242 |
+
model_dict[k] = v
|
243 |
+
state_dict.update(model_dict)
|
244 |
+
self.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
|
247 |
+
class MobileV2_MLSD_Large(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super(MobileV2_MLSD_Large, self).__init__()
|
250 |
+
|
251 |
+
self.backbone = MobileNetV2(pretrained=False)
|
252 |
+
## A, B
|
253 |
+
self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
|
254 |
+
out_c1= 64, out_c2=64,
|
255 |
+
upscale=False)
|
256 |
+
self.block16 = BlockTypeB(128, 64)
|
257 |
+
|
258 |
+
## A, B
|
259 |
+
self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
|
260 |
+
out_c1= 64, out_c2= 64)
|
261 |
+
self.block18 = BlockTypeB(128, 64)
|
262 |
+
|
263 |
+
## A, B
|
264 |
+
self.block19 = BlockTypeA(in_c1=24, in_c2=64,
|
265 |
+
out_c1=64, out_c2=64)
|
266 |
+
self.block20 = BlockTypeB(128, 64)
|
267 |
+
|
268 |
+
## A, B, C
|
269 |
+
self.block21 = BlockTypeA(in_c1=16, in_c2=64,
|
270 |
+
out_c1=64, out_c2=64)
|
271 |
+
self.block22 = BlockTypeB(128, 64)
|
272 |
+
|
273 |
+
self.block23 = BlockTypeC(64, 16)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
c1, c2, c3, c4, c5 = self.backbone(x)
|
277 |
+
|
278 |
+
x = self.block15(c4, c5)
|
279 |
+
x = self.block16(x)
|
280 |
+
|
281 |
+
x = self.block17(c3, x)
|
282 |
+
x = self.block18(x)
|
283 |
+
|
284 |
+
x = self.block19(c2, x)
|
285 |
+
x = self.block20(x)
|
286 |
+
|
287 |
+
x = self.block21(c1, x)
|
288 |
+
x = self.block22(x)
|
289 |
+
x = self.block23(x)
|
290 |
+
x = x[:, 7:, :, :]
|
291 |
+
|
292 |
+
return x
|
src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
14 |
+
nn.BatchNorm2d(out_c2),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
self.conv2 = nn.Sequential(
|
18 |
+
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(out_c1),
|
20 |
+
nn.ReLU(inplace=True)
|
21 |
+
)
|
22 |
+
self.upscale = upscale
|
23 |
+
|
24 |
+
def forward(self, a, b):
|
25 |
+
b = self.conv1(b)
|
26 |
+
a = self.conv2(a)
|
27 |
+
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
28 |
+
return torch.cat((a, b), dim=1)
|
29 |
+
|
30 |
+
|
31 |
+
class BlockTypeB(nn.Module):
|
32 |
+
def __init__(self, in_c, out_c):
|
33 |
+
super(BlockTypeB, self).__init__()
|
34 |
+
self.conv1 = nn.Sequential(
|
35 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
36 |
+
nn.BatchNorm2d(in_c),
|
37 |
+
nn.ReLU()
|
38 |
+
)
|
39 |
+
self.conv2 = nn.Sequential(
|
40 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
41 |
+
nn.BatchNorm2d(out_c),
|
42 |
+
nn.ReLU()
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = self.conv1(x) + x
|
47 |
+
x = self.conv2(x)
|
48 |
+
return x
|
49 |
+
|
50 |
+
class BlockTypeC(nn.Module):
|
51 |
+
def __init__(self, in_c, out_c):
|
52 |
+
super(BlockTypeC, self).__init__()
|
53 |
+
self.conv1 = nn.Sequential(
|
54 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
55 |
+
nn.BatchNorm2d(in_c),
|
56 |
+
nn.ReLU()
|
57 |
+
)
|
58 |
+
self.conv2 = nn.Sequential(
|
59 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
60 |
+
nn.BatchNorm2d(in_c),
|
61 |
+
nn.ReLU()
|
62 |
+
)
|
63 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
x = self.conv1(x)
|
67 |
+
x = self.conv2(x)
|
68 |
+
x = self.conv3(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
def _make_divisible(v, divisor, min_value=None):
|
72 |
+
"""
|
73 |
+
This function is taken from the original tf repo.
|
74 |
+
It ensures that all layers have a channel number that is divisible by 8
|
75 |
+
It can be seen here:
|
76 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
77 |
+
:param v:
|
78 |
+
:param divisor:
|
79 |
+
:param min_value:
|
80 |
+
:return:
|
81 |
+
"""
|
82 |
+
if min_value is None:
|
83 |
+
min_value = divisor
|
84 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
85 |
+
# Make sure that round down does not go down by more than 10%.
|
86 |
+
if new_v < 0.9 * v:
|
87 |
+
new_v += divisor
|
88 |
+
return new_v
|
89 |
+
|
90 |
+
|
91 |
+
class ConvBNReLU(nn.Sequential):
|
92 |
+
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
93 |
+
self.channel_pad = out_planes - in_planes
|
94 |
+
self.stride = stride
|
95 |
+
#padding = (kernel_size - 1) // 2
|
96 |
+
|
97 |
+
# TFLite uses slightly different padding than PyTorch
|
98 |
+
if stride == 2:
|
99 |
+
padding = 0
|
100 |
+
else:
|
101 |
+
padding = (kernel_size - 1) // 2
|
102 |
+
|
103 |
+
super(ConvBNReLU, self).__init__(
|
104 |
+
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
105 |
+
nn.BatchNorm2d(out_planes),
|
106 |
+
nn.ReLU6(inplace=True)
|
107 |
+
)
|
108 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
109 |
+
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
# TFLite uses different padding
|
113 |
+
if self.stride == 2:
|
114 |
+
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
115 |
+
#print(x.shape)
|
116 |
+
|
117 |
+
for module in self:
|
118 |
+
if not isinstance(module, nn.MaxPool2d):
|
119 |
+
x = module(x)
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
class InvertedResidual(nn.Module):
|
124 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
125 |
+
super(InvertedResidual, self).__init__()
|
126 |
+
self.stride = stride
|
127 |
+
assert stride in [1, 2]
|
128 |
+
|
129 |
+
hidden_dim = int(round(inp * expand_ratio))
|
130 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
131 |
+
|
132 |
+
layers = []
|
133 |
+
if expand_ratio != 1:
|
134 |
+
# pw
|
135 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
136 |
+
layers.extend([
|
137 |
+
# dw
|
138 |
+
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
139 |
+
# pw-linear
|
140 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
141 |
+
nn.BatchNorm2d(oup),
|
142 |
+
])
|
143 |
+
self.conv = nn.Sequential(*layers)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
if self.use_res_connect:
|
147 |
+
return x + self.conv(x)
|
148 |
+
else:
|
149 |
+
return self.conv(x)
|
150 |
+
|
151 |
+
|
152 |
+
class MobileNetV2(nn.Module):
|
153 |
+
def __init__(self, pretrained=True):
|
154 |
+
"""
|
155 |
+
MobileNet V2 main class
|
156 |
+
Args:
|
157 |
+
num_classes (int): Number of classes
|
158 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
159 |
+
inverted_residual_setting: Network structure
|
160 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
161 |
+
Set to 1 to turn off rounding
|
162 |
+
block: Module specifying inverted residual building block for mobilenet
|
163 |
+
"""
|
164 |
+
super(MobileNetV2, self).__init__()
|
165 |
+
|
166 |
+
block = InvertedResidual
|
167 |
+
input_channel = 32
|
168 |
+
last_channel = 1280
|
169 |
+
width_mult = 1.0
|
170 |
+
round_nearest = 8
|
171 |
+
|
172 |
+
inverted_residual_setting = [
|
173 |
+
# t, c, n, s
|
174 |
+
[1, 16, 1, 1],
|
175 |
+
[6, 24, 2, 2],
|
176 |
+
[6, 32, 3, 2],
|
177 |
+
[6, 64, 4, 2],
|
178 |
+
#[6, 96, 3, 1],
|
179 |
+
#[6, 160, 3, 2],
|
180 |
+
#[6, 320, 1, 1],
|
181 |
+
]
|
182 |
+
|
183 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
184 |
+
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
185 |
+
raise ValueError("inverted_residual_setting should be non-empty "
|
186 |
+
"or a 4-element list, got {}".format(inverted_residual_setting))
|
187 |
+
|
188 |
+
# building first layer
|
189 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
190 |
+
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
191 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
192 |
+
# building inverted residual blocks
|
193 |
+
for t, c, n, s in inverted_residual_setting:
|
194 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
195 |
+
for i in range(n):
|
196 |
+
stride = s if i == 0 else 1
|
197 |
+
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
198 |
+
input_channel = output_channel
|
199 |
+
self.features = nn.Sequential(*features)
|
200 |
+
|
201 |
+
self.fpn_selected = [3, 6, 10]
|
202 |
+
# weight initialization
|
203 |
+
for m in self.modules():
|
204 |
+
if isinstance(m, nn.Conv2d):
|
205 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
206 |
+
if m.bias is not None:
|
207 |
+
nn.init.zeros_(m.bias)
|
208 |
+
elif isinstance(m, nn.BatchNorm2d):
|
209 |
+
nn.init.ones_(m.weight)
|
210 |
+
nn.init.zeros_(m.bias)
|
211 |
+
elif isinstance(m, nn.Linear):
|
212 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
213 |
+
nn.init.zeros_(m.bias)
|
214 |
+
|
215 |
+
#if pretrained:
|
216 |
+
# self._load_pretrained_model()
|
217 |
+
|
218 |
+
def _forward_impl(self, x):
|
219 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
220 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
221 |
+
fpn_features = []
|
222 |
+
for i, f in enumerate(self.features):
|
223 |
+
if i > self.fpn_selected[-1]:
|
224 |
+
break
|
225 |
+
x = f(x)
|
226 |
+
if i in self.fpn_selected:
|
227 |
+
fpn_features.append(x)
|
228 |
+
|
229 |
+
c2, c3, c4 = fpn_features
|
230 |
+
return c2, c3, c4
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return self._forward_impl(x)
|
235 |
+
|
236 |
+
def _load_pretrained_model(self):
|
237 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
238 |
+
model_dict = {}
|
239 |
+
state_dict = self.state_dict()
|
240 |
+
for k, v in pretrain_dict.items():
|
241 |
+
if k in state_dict:
|
242 |
+
model_dict[k] = v
|
243 |
+
state_dict.update(model_dict)
|
244 |
+
self.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
|
247 |
+
class MobileV2_MLSD_Tiny(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super(MobileV2_MLSD_Tiny, self).__init__()
|
250 |
+
|
251 |
+
self.backbone = MobileNetV2(pretrained=True)
|
252 |
+
|
253 |
+
self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
|
254 |
+
out_c1= 64, out_c2=64)
|
255 |
+
self.block13 = BlockTypeB(128, 64)
|
256 |
+
|
257 |
+
self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
|
258 |
+
out_c1= 32, out_c2= 32)
|
259 |
+
self.block15 = BlockTypeB(64, 64)
|
260 |
+
|
261 |
+
self.block16 = BlockTypeC(64, 16)
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
c2, c3, c4 = self.backbone(x)
|
265 |
+
|
266 |
+
x = self.block12(c3, c4)
|
267 |
+
x = self.block13(x)
|
268 |
+
x = self.block14(c2, x)
|
269 |
+
x = self.block15(x)
|
270 |
+
x = self.block16(x)
|
271 |
+
x = x[:, 7:, :, :]
|
272 |
+
#print(x.shape)
|
273 |
+
x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
|
274 |
+
|
275 |
+
return x
|
src/flux/annotator/mlsd/utils.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
modified by lihaoweicv
|
3 |
+
pytorch version
|
4 |
+
'''
|
5 |
+
|
6 |
+
'''
|
7 |
+
M-LSD
|
8 |
+
Copyright 2021-present NAVER Corp.
|
9 |
+
Apache License v2.0
|
10 |
+
'''
|
11 |
+
|
12 |
+
import os
|
13 |
+
import numpy as np
|
14 |
+
import cv2
|
15 |
+
import torch
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
|
19 |
+
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
|
20 |
+
'''
|
21 |
+
tpMap:
|
22 |
+
center: tpMap[1, 0, :, :]
|
23 |
+
displacement: tpMap[1, 1:5, :, :]
|
24 |
+
'''
|
25 |
+
b, c, h, w = tpMap.shape
|
26 |
+
assert b==1, 'only support bsize==1'
|
27 |
+
displacement = tpMap[:, 1:5, :, :][0]
|
28 |
+
center = tpMap[:, 0, :, :]
|
29 |
+
heat = torch.sigmoid(center)
|
30 |
+
hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
|
31 |
+
keep = (hmax == heat).float()
|
32 |
+
heat = heat * keep
|
33 |
+
heat = heat.reshape(-1, )
|
34 |
+
|
35 |
+
scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
|
36 |
+
yy = torch.floor_divide(indices, w).unsqueeze(-1)
|
37 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
38 |
+
ptss = torch.cat((yy, xx),dim=-1)
|
39 |
+
|
40 |
+
ptss = ptss.detach().cpu().numpy()
|
41 |
+
scores = scores.detach().cpu().numpy()
|
42 |
+
displacement = displacement.detach().cpu().numpy()
|
43 |
+
displacement = displacement.transpose((1,2,0))
|
44 |
+
return ptss, scores, displacement
|
45 |
+
|
46 |
+
|
47 |
+
def pred_lines(image, model,
|
48 |
+
input_shape=[512, 512],
|
49 |
+
score_thr=0.10,
|
50 |
+
dist_thr=20.0):
|
51 |
+
h, w, _ = image.shape
|
52 |
+
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
53 |
+
|
54 |
+
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
|
55 |
+
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
56 |
+
|
57 |
+
resized_image = resized_image.transpose((2,0,1))
|
58 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
59 |
+
batch_image = (batch_image / 127.5) - 1.0
|
60 |
+
|
61 |
+
batch_image = torch.from_numpy(batch_image).float().to("cuda:4")
|
62 |
+
outputs = model(batch_image)
|
63 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
64 |
+
start = vmap[:, :, :2]
|
65 |
+
end = vmap[:, :, 2:]
|
66 |
+
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
67 |
+
|
68 |
+
segments_list = []
|
69 |
+
for center, score in zip(pts, pts_score):
|
70 |
+
y, x = center
|
71 |
+
distance = dist_map[y, x]
|
72 |
+
if score > score_thr and distance > dist_thr:
|
73 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
74 |
+
x_start = x + disp_x_start
|
75 |
+
y_start = y + disp_y_start
|
76 |
+
x_end = x + disp_x_end
|
77 |
+
y_end = y + disp_y_end
|
78 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
79 |
+
|
80 |
+
lines = 2 * np.array(segments_list) # 256 > 512
|
81 |
+
lines[:, 0] = lines[:, 0] * w_ratio
|
82 |
+
lines[:, 1] = lines[:, 1] * h_ratio
|
83 |
+
lines[:, 2] = lines[:, 2] * w_ratio
|
84 |
+
lines[:, 3] = lines[:, 3] * h_ratio
|
85 |
+
|
86 |
+
return lines
|
87 |
+
|
88 |
+
|
89 |
+
def pred_squares(image,
|
90 |
+
model,
|
91 |
+
input_shape=[512, 512],
|
92 |
+
params={'score': 0.06,
|
93 |
+
'outside_ratio': 0.28,
|
94 |
+
'inside_ratio': 0.45,
|
95 |
+
'w_overlap': 0.0,
|
96 |
+
'w_degree': 1.95,
|
97 |
+
'w_length': 0.0,
|
98 |
+
'w_area': 1.86,
|
99 |
+
'w_center': 0.14}):
|
100 |
+
'''
|
101 |
+
shape = [height, width]
|
102 |
+
'''
|
103 |
+
h, w, _ = image.shape
|
104 |
+
original_shape = [h, w]
|
105 |
+
|
106 |
+
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
|
107 |
+
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
108 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
109 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
110 |
+
batch_image = (batch_image / 127.5) - 1.0
|
111 |
+
|
112 |
+
batch_image = torch.from_numpy(batch_image).float().cuda()
|
113 |
+
outputs = model(batch_image)
|
114 |
+
|
115 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
116 |
+
start = vmap[:, :, :2] # (x, y)
|
117 |
+
end = vmap[:, :, 2:] # (x, y)
|
118 |
+
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
119 |
+
|
120 |
+
junc_list = []
|
121 |
+
segments_list = []
|
122 |
+
for junc, score in zip(pts, pts_score):
|
123 |
+
y, x = junc
|
124 |
+
distance = dist_map[y, x]
|
125 |
+
if score > params['score'] and distance > 20.0:
|
126 |
+
junc_list.append([x, y])
|
127 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
128 |
+
d_arrow = 1.0
|
129 |
+
x_start = x + d_arrow * disp_x_start
|
130 |
+
y_start = y + d_arrow * disp_y_start
|
131 |
+
x_end = x + d_arrow * disp_x_end
|
132 |
+
y_end = y + d_arrow * disp_y_end
|
133 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
134 |
+
|
135 |
+
segments = np.array(segments_list)
|
136 |
+
|
137 |
+
####### post processing for squares
|
138 |
+
# 1. get unique lines
|
139 |
+
point = np.array([[0, 0]])
|
140 |
+
point = point[0]
|
141 |
+
start = segments[:, :2]
|
142 |
+
end = segments[:, 2:]
|
143 |
+
diff = start - end
|
144 |
+
a = diff[:, 1]
|
145 |
+
b = -diff[:, 0]
|
146 |
+
c = a * start[:, 0] + b * start[:, 1]
|
147 |
+
|
148 |
+
d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
|
149 |
+
theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
|
150 |
+
theta[theta < 0.0] += 180
|
151 |
+
hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
|
152 |
+
|
153 |
+
d_quant = 1
|
154 |
+
theta_quant = 2
|
155 |
+
hough[:, 0] //= d_quant
|
156 |
+
hough[:, 1] //= theta_quant
|
157 |
+
_, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
|
158 |
+
|
159 |
+
acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
|
160 |
+
idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
|
161 |
+
yx_indices = hough[indices, :].astype('int32')
|
162 |
+
acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
|
163 |
+
idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
|
164 |
+
|
165 |
+
acc_map_np = acc_map
|
166 |
+
# acc_map = acc_map[None, :, :, None]
|
167 |
+
#
|
168 |
+
# ### fast suppression using tensorflow op
|
169 |
+
# acc_map = tf.constant(acc_map, dtype=tf.float32)
|
170 |
+
# max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
|
171 |
+
# acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
|
172 |
+
# flatten_acc_map = tf.reshape(acc_map, [1, -1])
|
173 |
+
# topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
|
174 |
+
# _, h, w, _ = acc_map.shape
|
175 |
+
# y = tf.expand_dims(topk_indices // w, axis=-1)
|
176 |
+
# x = tf.expand_dims(topk_indices % w, axis=-1)
|
177 |
+
# yx = tf.concat([y, x], axis=-1)
|
178 |
+
|
179 |
+
### fast suppression using pytorch op
|
180 |
+
acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
|
181 |
+
_,_, h, w = acc_map.shape
|
182 |
+
max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
|
183 |
+
acc_map = acc_map * ( (acc_map == max_acc_map).float() )
|
184 |
+
flatten_acc_map = acc_map.reshape([-1, ])
|
185 |
+
|
186 |
+
scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
|
187 |
+
yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
|
188 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
189 |
+
yx = torch.cat((yy, xx), dim=-1)
|
190 |
+
|
191 |
+
yx = yx.detach().cpu().numpy()
|
192 |
+
|
193 |
+
topk_values = scores.detach().cpu().numpy()
|
194 |
+
indices = idx_map[yx[:, 0], yx[:, 1]]
|
195 |
+
basis = 5 // 2
|
196 |
+
|
197 |
+
merged_segments = []
|
198 |
+
for yx_pt, max_indice, value in zip(yx, indices, topk_values):
|
199 |
+
y, x = yx_pt
|
200 |
+
if max_indice == -1 or value == 0:
|
201 |
+
continue
|
202 |
+
segment_list = []
|
203 |
+
for y_offset in range(-basis, basis + 1):
|
204 |
+
for x_offset in range(-basis, basis + 1):
|
205 |
+
indice = idx_map[y + y_offset, x + x_offset]
|
206 |
+
cnt = int(acc_map_np[y + y_offset, x + x_offset])
|
207 |
+
if indice != -1:
|
208 |
+
segment_list.append(segments[indice])
|
209 |
+
if cnt > 1:
|
210 |
+
check_cnt = 1
|
211 |
+
current_hough = hough[indice]
|
212 |
+
for new_indice, new_hough in enumerate(hough):
|
213 |
+
if (current_hough == new_hough).all() and indice != new_indice:
|
214 |
+
segment_list.append(segments[new_indice])
|
215 |
+
check_cnt += 1
|
216 |
+
if check_cnt == cnt:
|
217 |
+
break
|
218 |
+
group_segments = np.array(segment_list).reshape([-1, 2])
|
219 |
+
sorted_group_segments = np.sort(group_segments, axis=0)
|
220 |
+
x_min, y_min = sorted_group_segments[0, :]
|
221 |
+
x_max, y_max = sorted_group_segments[-1, :]
|
222 |
+
|
223 |
+
deg = theta[max_indice]
|
224 |
+
if deg >= 90:
|
225 |
+
merged_segments.append([x_min, y_max, x_max, y_min])
|
226 |
+
else:
|
227 |
+
merged_segments.append([x_min, y_min, x_max, y_max])
|
228 |
+
|
229 |
+
# 2. get intersections
|
230 |
+
new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
|
231 |
+
start = new_segments[:, :2] # (x1, y1)
|
232 |
+
end = new_segments[:, 2:] # (x2, y2)
|
233 |
+
new_centers = (start + end) / 2.0
|
234 |
+
diff = start - end
|
235 |
+
dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
|
236 |
+
|
237 |
+
# ax + by = c
|
238 |
+
a = diff[:, 1]
|
239 |
+
b = -diff[:, 0]
|
240 |
+
c = a * start[:, 0] + b * start[:, 1]
|
241 |
+
pre_det = a[:, None] * b[None, :]
|
242 |
+
det = pre_det - np.transpose(pre_det)
|
243 |
+
|
244 |
+
pre_inter_y = a[:, None] * c[None, :]
|
245 |
+
inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
|
246 |
+
pre_inter_x = c[:, None] * b[None, :]
|
247 |
+
inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
|
248 |
+
inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
|
249 |
+
|
250 |
+
# 3. get corner information
|
251 |
+
# 3.1 get distance
|
252 |
+
'''
|
253 |
+
dist_segments:
|
254 |
+
| dist(0), dist(1), dist(2), ...|
|
255 |
+
dist_inter_to_segment1:
|
256 |
+
| dist(inter,0), dist(inter,0), dist(inter,0), ... |
|
257 |
+
| dist(inter,1), dist(inter,1), dist(inter,1), ... |
|
258 |
+
...
|
259 |
+
dist_inter_to_semgnet2:
|
260 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
261 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
262 |
+
...
|
263 |
+
'''
|
264 |
+
|
265 |
+
dist_inter_to_segment1_start = np.sqrt(
|
266 |
+
np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
267 |
+
dist_inter_to_segment1_end = np.sqrt(
|
268 |
+
np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
269 |
+
dist_inter_to_segment2_start = np.sqrt(
|
270 |
+
np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
271 |
+
dist_inter_to_segment2_end = np.sqrt(
|
272 |
+
np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
273 |
+
|
274 |
+
# sort ascending
|
275 |
+
dist_inter_to_segment1 = np.sort(
|
276 |
+
np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
|
277 |
+
axis=-1) # [n_batch, n_batch, 2]
|
278 |
+
dist_inter_to_segment2 = np.sort(
|
279 |
+
np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
|
280 |
+
axis=-1) # [n_batch, n_batch, 2]
|
281 |
+
|
282 |
+
# 3.2 get degree
|
283 |
+
inter_to_start = new_centers[:, None, :] - inter_pts
|
284 |
+
deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
|
285 |
+
deg_inter_to_start[deg_inter_to_start < 0.0] += 360
|
286 |
+
inter_to_end = new_centers[None, :, :] - inter_pts
|
287 |
+
deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
|
288 |
+
deg_inter_to_end[deg_inter_to_end < 0.0] += 360
|
289 |
+
|
290 |
+
'''
|
291 |
+
B -- G
|
292 |
+
| |
|
293 |
+
C -- R
|
294 |
+
B : blue / G: green / C: cyan / R: red
|
295 |
+
|
296 |
+
0 -- 1
|
297 |
+
| |
|
298 |
+
3 -- 2
|
299 |
+
'''
|
300 |
+
# rename variables
|
301 |
+
deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
|
302 |
+
# sort deg ascending
|
303 |
+
deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
|
304 |
+
|
305 |
+
deg_diff_map = np.abs(deg1_map - deg2_map)
|
306 |
+
# we only consider the smallest degree of intersect
|
307 |
+
deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
|
308 |
+
|
309 |
+
# define available degree range
|
310 |
+
deg_range = [60, 120]
|
311 |
+
|
312 |
+
corner_dict = {corner_info: [] for corner_info in range(4)}
|
313 |
+
inter_points = []
|
314 |
+
for i in range(inter_pts.shape[0]):
|
315 |
+
for j in range(i + 1, inter_pts.shape[1]):
|
316 |
+
# i, j > line index, always i < j
|
317 |
+
x, y = inter_pts[i, j, :]
|
318 |
+
deg1, deg2 = deg_sort[i, j, :]
|
319 |
+
deg_diff = deg_diff_map[i, j]
|
320 |
+
|
321 |
+
check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
|
322 |
+
|
323 |
+
outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
|
324 |
+
inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
|
325 |
+
check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
|
326 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
|
327 |
+
(dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
|
328 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
|
329 |
+
((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
|
330 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
|
331 |
+
(dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
|
332 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
|
333 |
+
|
334 |
+
if check_degree and check_distance:
|
335 |
+
corner_info = None
|
336 |
+
|
337 |
+
if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
|
338 |
+
(deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
|
339 |
+
corner_info, color_info = 0, 'blue'
|
340 |
+
elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
|
341 |
+
corner_info, color_info = 1, 'green'
|
342 |
+
elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
|
343 |
+
corner_info, color_info = 2, 'black'
|
344 |
+
elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
|
345 |
+
(deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
|
346 |
+
corner_info, color_info = 3, 'cyan'
|
347 |
+
else:
|
348 |
+
corner_info, color_info = 4, 'red' # we don't use it
|
349 |
+
continue
|
350 |
+
|
351 |
+
corner_dict[corner_info].append([x, y, i, j])
|
352 |
+
inter_points.append([x, y])
|
353 |
+
|
354 |
+
square_list = []
|
355 |
+
connect_list = []
|
356 |
+
segments_list = []
|
357 |
+
for corner0 in corner_dict[0]:
|
358 |
+
for corner1 in corner_dict[1]:
|
359 |
+
connect01 = False
|
360 |
+
for corner0_line in corner0[2:]:
|
361 |
+
if corner0_line in corner1[2:]:
|
362 |
+
connect01 = True
|
363 |
+
break
|
364 |
+
if connect01:
|
365 |
+
for corner2 in corner_dict[2]:
|
366 |
+
connect12 = False
|
367 |
+
for corner1_line in corner1[2:]:
|
368 |
+
if corner1_line in corner2[2:]:
|
369 |
+
connect12 = True
|
370 |
+
break
|
371 |
+
if connect12:
|
372 |
+
for corner3 in corner_dict[3]:
|
373 |
+
connect23 = False
|
374 |
+
for corner2_line in corner2[2:]:
|
375 |
+
if corner2_line in corner3[2:]:
|
376 |
+
connect23 = True
|
377 |
+
break
|
378 |
+
if connect23:
|
379 |
+
for corner3_line in corner3[2:]:
|
380 |
+
if corner3_line in corner0[2:]:
|
381 |
+
# SQUARE!!!
|
382 |
+
'''
|
383 |
+
0 -- 1
|
384 |
+
| |
|
385 |
+
3 -- 2
|
386 |
+
square_list:
|
387 |
+
order: 0 > 1 > 2 > 3
|
388 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
389 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
390 |
+
...
|
391 |
+
connect_list:
|
392 |
+
order: 01 > 12 > 23 > 30
|
393 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
394 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
395 |
+
...
|
396 |
+
segments_list:
|
397 |
+
order: 0 > 1 > 2 > 3
|
398 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
399 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
400 |
+
...
|
401 |
+
'''
|
402 |
+
square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
|
403 |
+
connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
|
404 |
+
segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
|
405 |
+
|
406 |
+
def check_outside_inside(segments_info, connect_idx):
|
407 |
+
# return 'outside or inside', min distance, cover_param, peri_param
|
408 |
+
if connect_idx == segments_info[0]:
|
409 |
+
check_dist_mat = dist_inter_to_segment1
|
410 |
+
else:
|
411 |
+
check_dist_mat = dist_inter_to_segment2
|
412 |
+
|
413 |
+
i, j = segments_info
|
414 |
+
min_dist, max_dist = check_dist_mat[i, j, :]
|
415 |
+
connect_dist = dist_segments[connect_idx]
|
416 |
+
if max_dist > connect_dist:
|
417 |
+
return 'outside', min_dist, 0, 1
|
418 |
+
else:
|
419 |
+
return 'inside', min_dist, -1, -1
|
420 |
+
|
421 |
+
top_square = None
|
422 |
+
|
423 |
+
try:
|
424 |
+
map_size = input_shape[0] / 2
|
425 |
+
squares = np.array(square_list).reshape([-1, 4, 2])
|
426 |
+
score_array = []
|
427 |
+
connect_array = np.array(connect_list)
|
428 |
+
segments_array = np.array(segments_list).reshape([-1, 4, 2])
|
429 |
+
|
430 |
+
# get degree of corners:
|
431 |
+
squares_rollup = np.roll(squares, 1, axis=1)
|
432 |
+
squares_rolldown = np.roll(squares, -1, axis=1)
|
433 |
+
vec1 = squares_rollup - squares
|
434 |
+
normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
|
435 |
+
vec2 = squares_rolldown - squares
|
436 |
+
normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
|
437 |
+
inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
|
438 |
+
squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
|
439 |
+
|
440 |
+
# get square score
|
441 |
+
overlap_scores = []
|
442 |
+
degree_scores = []
|
443 |
+
length_scores = []
|
444 |
+
|
445 |
+
for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
|
446 |
+
'''
|
447 |
+
0 -- 1
|
448 |
+
| |
|
449 |
+
3 -- 2
|
450 |
+
|
451 |
+
# segments: [4, 2]
|
452 |
+
# connects: [4]
|
453 |
+
'''
|
454 |
+
|
455 |
+
###################################### OVERLAP SCORES
|
456 |
+
cover = 0
|
457 |
+
perimeter = 0
|
458 |
+
# check 0 > 1 > 2 > 3
|
459 |
+
square_length = []
|
460 |
+
|
461 |
+
for start_idx in range(4):
|
462 |
+
end_idx = (start_idx + 1) % 4
|
463 |
+
|
464 |
+
connect_idx = connects[start_idx] # segment idx of segment01
|
465 |
+
start_segments = segments[start_idx]
|
466 |
+
end_segments = segments[end_idx]
|
467 |
+
|
468 |
+
start_point = square[start_idx]
|
469 |
+
end_point = square[end_idx]
|
470 |
+
|
471 |
+
# check whether outside or inside
|
472 |
+
start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
|
473 |
+
connect_idx)
|
474 |
+
end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
|
475 |
+
|
476 |
+
cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
|
477 |
+
perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
|
478 |
+
|
479 |
+
square_length.append(
|
480 |
+
dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
|
481 |
+
|
482 |
+
overlap_scores.append(cover / perimeter)
|
483 |
+
######################################
|
484 |
+
###################################### DEGREE SCORES
|
485 |
+
'''
|
486 |
+
deg0 vs deg2
|
487 |
+
deg1 vs deg3
|
488 |
+
'''
|
489 |
+
deg0, deg1, deg2, deg3 = degree
|
490 |
+
deg_ratio1 = deg0 / deg2
|
491 |
+
if deg_ratio1 > 1.0:
|
492 |
+
deg_ratio1 = 1 / deg_ratio1
|
493 |
+
deg_ratio2 = deg1 / deg3
|
494 |
+
if deg_ratio2 > 1.0:
|
495 |
+
deg_ratio2 = 1 / deg_ratio2
|
496 |
+
degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
|
497 |
+
######################################
|
498 |
+
###################################### LENGTH SCORES
|
499 |
+
'''
|
500 |
+
len0 vs len2
|
501 |
+
len1 vs len3
|
502 |
+
'''
|
503 |
+
len0, len1, len2, len3 = square_length
|
504 |
+
len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
|
505 |
+
len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
|
506 |
+
length_scores.append((len_ratio1 + len_ratio2) / 2)
|
507 |
+
|
508 |
+
######################################
|
509 |
+
|
510 |
+
overlap_scores = np.array(overlap_scores)
|
511 |
+
overlap_scores /= np.max(overlap_scores)
|
512 |
+
|
513 |
+
degree_scores = np.array(degree_scores)
|
514 |
+
# degree_scores /= np.max(degree_scores)
|
515 |
+
|
516 |
+
length_scores = np.array(length_scores)
|
517 |
+
|
518 |
+
###################################### AREA SCORES
|
519 |
+
area_scores = np.reshape(squares, [-1, 4, 2])
|
520 |
+
area_x = area_scores[:, :, 0]
|
521 |
+
area_y = area_scores[:, :, 1]
|
522 |
+
correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
|
523 |
+
area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
|
524 |
+
area_scores = 0.5 * np.abs(area_scores + correction)
|
525 |
+
area_scores /= (map_size * map_size) # np.max(area_scores)
|
526 |
+
######################################
|
527 |
+
|
528 |
+
###################################### CENTER SCORES
|
529 |
+
centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
|
530 |
+
# squares: [n, 4, 2]
|
531 |
+
square_centers = np.mean(squares, axis=1) # [n, 2]
|
532 |
+
center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
|
533 |
+
center_scores = center2center / (map_size / np.sqrt(2.0))
|
534 |
+
|
535 |
+
'''
|
536 |
+
score_w = [overlap, degree, area, center, length]
|
537 |
+
'''
|
538 |
+
score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
|
539 |
+
score_array = params['w_overlap'] * overlap_scores \
|
540 |
+
+ params['w_degree'] * degree_scores \
|
541 |
+
+ params['w_area'] * area_scores \
|
542 |
+
- params['w_center'] * center_scores \
|
543 |
+
+ params['w_length'] * length_scores
|
544 |
+
|
545 |
+
best_square = []
|
546 |
+
|
547 |
+
sorted_idx = np.argsort(score_array)[::-1]
|
548 |
+
score_array = score_array[sorted_idx]
|
549 |
+
squares = squares[sorted_idx]
|
550 |
+
|
551 |
+
except Exception as e:
|
552 |
+
pass
|
553 |
+
|
554 |
+
'''return list
|
555 |
+
merged_lines, squares, scores
|
556 |
+
'''
|
557 |
+
|
558 |
+
try:
|
559 |
+
new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
|
560 |
+
new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
|
561 |
+
new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
|
562 |
+
new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
|
563 |
+
except:
|
564 |
+
new_segments = []
|
565 |
+
|
566 |
+
try:
|
567 |
+
squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
|
568 |
+
squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
|
569 |
+
except:
|
570 |
+
squares = []
|
571 |
+
score_array = []
|
572 |
+
|
573 |
+
try:
|
574 |
+
inter_points = np.array(inter_points)
|
575 |
+
inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
|
576 |
+
inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
|
577 |
+
except:
|
578 |
+
inter_points = []
|
579 |
+
|
580 |
+
return new_segments, squares, score_array, inter_points
|
src/flux/annotator/tile/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import cv2
|
3 |
+
from .guided_filter import FastGuidedFilter
|
4 |
+
|
5 |
+
|
6 |
+
class TileDetector:
|
7 |
+
# https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0
|
8 |
+
def __init__(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def __call__(self, image):
|
12 |
+
blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0]
|
13 |
+
radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
|
14 |
+
eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0]
|
15 |
+
scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0]
|
16 |
+
|
17 |
+
ksize = int(blur_strength)
|
18 |
+
if ksize % 2 == 0:
|
19 |
+
ksize += 1
|
20 |
+
|
21 |
+
if random.random() > 0.5:
|
22 |
+
image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2)
|
23 |
+
if random.random() > 0.5:
|
24 |
+
filter = FastGuidedFilter(image, radius, eps, scale_factor)
|
25 |
+
image = filter.filter(image)
|
26 |
+
return image
|
src/flux/annotator/tile/guided_filter.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
## @package guided_filter.core.filters
|
3 |
+
#
|
4 |
+
# Implementation of guided filter.
|
5 |
+
# * GuidedFilter: Original guided filter.
|
6 |
+
# * FastGuidedFilter: Fast version of the guided filter.
|
7 |
+
# @author tody
|
8 |
+
# @date 2015/08/26
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import cv2
|
12 |
+
|
13 |
+
## Convert image into float32 type.
|
14 |
+
def to32F(img):
|
15 |
+
if img.dtype == np.float32:
|
16 |
+
return img
|
17 |
+
return (1.0 / 255.0) * np.float32(img)
|
18 |
+
|
19 |
+
## Convert image into uint8 type.
|
20 |
+
def to8U(img):
|
21 |
+
if img.dtype == np.uint8:
|
22 |
+
return img
|
23 |
+
return np.clip(np.uint8(255.0 * img), 0, 255)
|
24 |
+
|
25 |
+
## Return if the input image is gray or not.
|
26 |
+
def _isGray(I):
|
27 |
+
return len(I.shape) == 2
|
28 |
+
|
29 |
+
|
30 |
+
## Return down sampled image.
|
31 |
+
# @param scale (w/s, h/s) image will be created.
|
32 |
+
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
|
33 |
+
def _downSample(I, scale=4, shape=None):
|
34 |
+
if shape is not None:
|
35 |
+
h, w = shape
|
36 |
+
return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
|
37 |
+
|
38 |
+
h, w = I.shape[:2]
|
39 |
+
return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
|
40 |
+
|
41 |
+
|
42 |
+
## Return up sampled image.
|
43 |
+
# @param scale (w*s, h*s) image will be created.
|
44 |
+
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
|
45 |
+
def _upSample(I, scale=2, shape=None):
|
46 |
+
if shape is not None:
|
47 |
+
h, w = shape
|
48 |
+
return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
|
49 |
+
|
50 |
+
h, w = I.shape[:2]
|
51 |
+
return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
|
52 |
+
|
53 |
+
## Fast guide filter.
|
54 |
+
class FastGuidedFilter:
|
55 |
+
## Constructor.
|
56 |
+
# @param I Input guidance image. Color or gray.
|
57 |
+
# @param radius Radius of Guided Filter.
|
58 |
+
# @param epsilon Regularization term of Guided Filter.
|
59 |
+
# @param scale Down sampled scale.
|
60 |
+
def __init__(self, I, radius=5, epsilon=0.4, scale=4):
|
61 |
+
I_32F = to32F(I)
|
62 |
+
self._I = I_32F
|
63 |
+
h, w = I.shape[:2]
|
64 |
+
|
65 |
+
I_sub = _downSample(I_32F, scale)
|
66 |
+
|
67 |
+
self._I_sub = I_sub
|
68 |
+
radius = int(radius / scale)
|
69 |
+
|
70 |
+
if _isGray(I):
|
71 |
+
self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
|
72 |
+
else:
|
73 |
+
self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
|
74 |
+
|
75 |
+
## Apply filter for the input image.
|
76 |
+
# @param p Input image for the filtering.
|
77 |
+
def filter(self, p):
|
78 |
+
p_32F = to32F(p)
|
79 |
+
shape_original = p.shape[:2]
|
80 |
+
|
81 |
+
p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
|
82 |
+
|
83 |
+
if _isGray(p_sub):
|
84 |
+
return self._filterGray(p_sub, shape_original)
|
85 |
+
|
86 |
+
cs = p.shape[2]
|
87 |
+
q = np.array(p_32F)
|
88 |
+
|
89 |
+
for ci in range(cs):
|
90 |
+
q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
|
91 |
+
return to8U(q)
|
92 |
+
|
93 |
+
def _filterGray(self, p_sub, shape_original):
|
94 |
+
ab_sub = self._guided_filter._computeCoefficients(p_sub)
|
95 |
+
ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
|
96 |
+
return self._guided_filter._computeOutput(ab, self._I)
|
97 |
+
|
98 |
+
|
99 |
+
## Guide filter.
|
100 |
+
class GuidedFilter:
|
101 |
+
## Constructor.
|
102 |
+
# @param I Input guidance image. Color or gray.
|
103 |
+
# @param radius Radius of Guided Filter.
|
104 |
+
# @param epsilon Regularization term of Guided Filter.
|
105 |
+
def __init__(self, I, radius=5, epsilon=0.4):
|
106 |
+
I_32F = to32F(I)
|
107 |
+
|
108 |
+
if _isGray(I):
|
109 |
+
self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
|
110 |
+
else:
|
111 |
+
self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
|
112 |
+
|
113 |
+
## Apply filter for the input image.
|
114 |
+
# @param p Input image for the filtering.
|
115 |
+
def filter(self, p):
|
116 |
+
return to8U(self._guided_filter.filter(p))
|
117 |
+
|
118 |
+
|
119 |
+
## Common parts of guided filter.
|
120 |
+
#
|
121 |
+
# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
|
122 |
+
# Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
|
123 |
+
# GuidedFilterCommon.filter computes filtered image for color and gray.
|
124 |
+
class GuidedFilterCommon:
|
125 |
+
def __init__(self, guided_filter):
|
126 |
+
self._guided_filter = guided_filter
|
127 |
+
|
128 |
+
## Apply filter for the input image.
|
129 |
+
# @param p Input image for the filtering.
|
130 |
+
def filter(self, p):
|
131 |
+
p_32F = to32F(p)
|
132 |
+
if _isGray(p_32F):
|
133 |
+
return self._filterGray(p_32F)
|
134 |
+
|
135 |
+
cs = p.shape[2]
|
136 |
+
q = np.array(p_32F)
|
137 |
+
|
138 |
+
for ci in range(cs):
|
139 |
+
q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
|
140 |
+
return q
|
141 |
+
|
142 |
+
def _filterGray(self, p):
|
143 |
+
ab = self._guided_filter._computeCoefficients(p)
|
144 |
+
return self._guided_filter._computeOutput(ab, self._guided_filter._I)
|
145 |
+
|
146 |
+
|
147 |
+
## Guided filter for gray guidance image.
|
148 |
+
class GuidedFilterGray:
|
149 |
+
# @param I Input gray guidance image.
|
150 |
+
# @param radius Radius of Guided Filter.
|
151 |
+
# @param epsilon Regularization term of Guided Filter.
|
152 |
+
def __init__(self, I, radius=5, epsilon=0.4):
|
153 |
+
self._radius = 2 * radius + 1
|
154 |
+
self._epsilon = epsilon
|
155 |
+
self._I = to32F(I)
|
156 |
+
self._initFilter()
|
157 |
+
self._filter_common = GuidedFilterCommon(self)
|
158 |
+
|
159 |
+
## Apply filter for the input image.
|
160 |
+
# @param p Input image for the filtering.
|
161 |
+
def filter(self, p):
|
162 |
+
return self._filter_common.filter(p)
|
163 |
+
|
164 |
+
def _initFilter(self):
|
165 |
+
I = self._I
|
166 |
+
r = self._radius
|
167 |
+
self._I_mean = cv2.blur(I, (r, r))
|
168 |
+
I_mean_sq = cv2.blur(I ** 2, (r, r))
|
169 |
+
self._I_var = I_mean_sq - self._I_mean ** 2
|
170 |
+
|
171 |
+
def _computeCoefficients(self, p):
|
172 |
+
r = self._radius
|
173 |
+
p_mean = cv2.blur(p, (r, r))
|
174 |
+
p_cov = p_mean - self._I_mean * p_mean
|
175 |
+
a = p_cov / (self._I_var + self._epsilon)
|
176 |
+
b = p_mean - a * self._I_mean
|
177 |
+
a_mean = cv2.blur(a, (r, r))
|
178 |
+
b_mean = cv2.blur(b, (r, r))
|
179 |
+
return a_mean, b_mean
|
180 |
+
|
181 |
+
def _computeOutput(self, ab, I):
|
182 |
+
a_mean, b_mean = ab
|
183 |
+
return a_mean * I + b_mean
|
184 |
+
|
185 |
+
|
186 |
+
## Guided filter for color guidance image.
|
187 |
+
class GuidedFilterColor:
|
188 |
+
# @param I Input color guidance image.
|
189 |
+
# @param radius Radius of Guided Filter.
|
190 |
+
# @param epsilon Regularization term of Guided Filter.
|
191 |
+
def __init__(self, I, radius=5, epsilon=0.2):
|
192 |
+
self._radius = 2 * radius + 1
|
193 |
+
self._epsilon = epsilon
|
194 |
+
self._I = to32F(I)
|
195 |
+
self._initFilter()
|
196 |
+
self._filter_common = GuidedFilterCommon(self)
|
197 |
+
|
198 |
+
## Apply filter for the input image.
|
199 |
+
# @param p Input image for the filtering.
|
200 |
+
def filter(self, p):
|
201 |
+
return self._filter_common.filter(p)
|
202 |
+
|
203 |
+
def _initFilter(self):
|
204 |
+
I = self._I
|
205 |
+
r = self._radius
|
206 |
+
eps = self._epsilon
|
207 |
+
|
208 |
+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
|
209 |
+
|
210 |
+
self._Ir_mean = cv2.blur(Ir, (r, r))
|
211 |
+
self._Ig_mean = cv2.blur(Ig, (r, r))
|
212 |
+
self._Ib_mean = cv2.blur(Ib, (r, r))
|
213 |
+
|
214 |
+
Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
|
215 |
+
Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
|
216 |
+
Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
|
217 |
+
Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
|
218 |
+
Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
|
219 |
+
Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
|
220 |
+
|
221 |
+
Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
|
222 |
+
Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
|
223 |
+
Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
|
224 |
+
Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
|
225 |
+
Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
|
226 |
+
Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
|
227 |
+
|
228 |
+
I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
|
229 |
+
Irr_inv /= I_cov
|
230 |
+
Irg_inv /= I_cov
|
231 |
+
Irb_inv /= I_cov
|
232 |
+
Igg_inv /= I_cov
|
233 |
+
Igb_inv /= I_cov
|
234 |
+
Ibb_inv /= I_cov
|
235 |
+
|
236 |
+
self._Irr_inv = Irr_inv
|
237 |
+
self._Irg_inv = Irg_inv
|
238 |
+
self._Irb_inv = Irb_inv
|
239 |
+
self._Igg_inv = Igg_inv
|
240 |
+
self._Igb_inv = Igb_inv
|
241 |
+
self._Ibb_inv = Ibb_inv
|
242 |
+
|
243 |
+
def _computeCoefficients(self, p):
|
244 |
+
r = self._radius
|
245 |
+
I = self._I
|
246 |
+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
|
247 |
+
|
248 |
+
p_mean = cv2.blur(p, (r, r))
|
249 |
+
|
250 |
+
Ipr_mean = cv2.blur(Ir * p, (r, r))
|
251 |
+
Ipg_mean = cv2.blur(Ig * p, (r, r))
|
252 |
+
Ipb_mean = cv2.blur(Ib * p, (r, r))
|
253 |
+
|
254 |
+
Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
|
255 |
+
Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
|
256 |
+
Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
|
257 |
+
|
258 |
+
ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
|
259 |
+
ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
|
260 |
+
ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
|
261 |
+
b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
|
262 |
+
|
263 |
+
ar_mean = cv2.blur(ar, (r, r))
|
264 |
+
ag_mean = cv2.blur(ag, (r, r))
|
265 |
+
ab_mean = cv2.blur(ab, (r, r))
|
266 |
+
b_mean = cv2.blur(b, (r, r))
|
267 |
+
|
268 |
+
return ar_mean, ag_mean, ab_mean, b_mean
|
269 |
+
|
270 |
+
def _computeOutput(self, ab, I):
|
271 |
+
ar_mean, ag_mean, ab_mean, b_mean = ab
|
272 |
+
|
273 |
+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
|
274 |
+
|
275 |
+
q = (ar_mean * Ir +
|
276 |
+
ag_mean * Ig +
|
277 |
+
ab_mean * Ib +
|
278 |
+
b_mean)
|
279 |
+
|
280 |
+
return q
|
src/flux/annotator/util.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
7 |
+
|
8 |
+
|
9 |
+
def HWC3(x):
|
10 |
+
assert x.dtype == np.uint8
|
11 |
+
if x.ndim == 2:
|
12 |
+
x = x[:, :, None]
|
13 |
+
assert x.ndim == 3
|
14 |
+
H, W, C = x.shape
|
15 |
+
assert C == 1 or C == 3 or C == 4
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
|
28 |
+
def resize_image(input_image, resolution):
|
29 |
+
H, W, C = input_image.shape
|
30 |
+
H = float(H)
|
31 |
+
W = float(W)
|
32 |
+
k = float(resolution) / min(H, W)
|
33 |
+
H *= k
|
34 |
+
W *= k
|
35 |
+
H = int(np.round(H / 64.0)) * 64
|
36 |
+
W = int(np.round(W / 64.0)) * 64
|
37 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
38 |
+
return img
|
src/flux/annotator/zoe/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
src/flux/annotator/zoe/__init__.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ZoeDepth
|
2 |
+
# https://github.com/isl-org/ZoeDepth
|
3 |
+
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth
|
11 |
+
from .zoedepth.utils.config import get_config
|
12 |
+
from ...annotator.util import annotator_ckpts_path
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
|
15 |
+
|
16 |
+
class ZoeDetector:
|
17 |
+
def __init__(self):
|
18 |
+
model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt")
|
19 |
+
if not os.path.exists(model_path):
|
20 |
+
model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt")
|
21 |
+
conf = get_config("zoedepth", "infer")
|
22 |
+
model = ZoeDepth.build_from_config(conf)
|
23 |
+
model.load_state_dict(torch.load(model_path)['model'], strict=False)
|
24 |
+
model = model.cuda()
|
25 |
+
model.device = 'cuda'
|
26 |
+
model.eval()
|
27 |
+
self.model = model
|
28 |
+
|
29 |
+
def __call__(self, input_image):
|
30 |
+
assert input_image.ndim == 3
|
31 |
+
image_depth = input_image
|
32 |
+
with torch.no_grad():
|
33 |
+
image_depth = torch.from_numpy(image_depth).float().cuda()
|
34 |
+
image_depth = image_depth / 255.0
|
35 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
36 |
+
depth = self.model.infer(image_depth)
|
37 |
+
|
38 |
+
depth = depth[0, 0].cpu().numpy()
|
39 |
+
|
40 |
+
vmin = np.percentile(depth, 2)
|
41 |
+
vmax = np.percentile(depth, 85)
|
42 |
+
|
43 |
+
depth -= vmin
|
44 |
+
depth /= vmax - vmin
|
45 |
+
depth = 1.0 - depth
|
46 |
+
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
|
47 |
+
|
48 |
+
return depth_image
|
src/flux/annotator/zoe/zoedepth/data/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
src/flux/annotator/zoe/zoedepth/data/data_mono.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
|
26 |
+
|
27 |
+
import itertools
|
28 |
+
import os
|
29 |
+
import random
|
30 |
+
|
31 |
+
import numpy as np
|
32 |
+
import cv2
|
33 |
+
import torch
|
34 |
+
import torch.nn as nn
|
35 |
+
import torch.utils.data.distributed
|
36 |
+
from zoedepth.utils.easydict import EasyDict as edict
|
37 |
+
from PIL import Image, ImageOps
|
38 |
+
from torch.utils.data import DataLoader, Dataset
|
39 |
+
from torchvision import transforms
|
40 |
+
|
41 |
+
from zoedepth.utils.config import change_dataset
|
42 |
+
|
43 |
+
from .ddad import get_ddad_loader
|
44 |
+
from .diml_indoor_test import get_diml_indoor_loader
|
45 |
+
from .diml_outdoor_test import get_diml_outdoor_loader
|
46 |
+
from .diode import get_diode_loader
|
47 |
+
from .hypersim import get_hypersim_loader
|
48 |
+
from .ibims import get_ibims_loader
|
49 |
+
from .sun_rgbd_loader import get_sunrgbd_loader
|
50 |
+
from .vkitti import get_vkitti_loader
|
51 |
+
from .vkitti2 import get_vkitti2_loader
|
52 |
+
|
53 |
+
from .preprocess import CropParams, get_white_border, get_black_border
|
54 |
+
|
55 |
+
|
56 |
+
def _is_pil_image(img):
|
57 |
+
return isinstance(img, Image.Image)
|
58 |
+
|
59 |
+
|
60 |
+
def _is_numpy_image(img):
|
61 |
+
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
62 |
+
|
63 |
+
|
64 |
+
def preprocessing_transforms(mode, **kwargs):
|
65 |
+
return transforms.Compose([
|
66 |
+
ToTensor(mode=mode, **kwargs)
|
67 |
+
])
|
68 |
+
|
69 |
+
|
70 |
+
class DepthDataLoader(object):
|
71 |
+
def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
|
72 |
+
"""
|
73 |
+
Data loader for depth datasets
|
74 |
+
|
75 |
+
Args:
|
76 |
+
config (dict): Config dictionary. Refer to utils/config.py
|
77 |
+
mode (str): "train" or "online_eval"
|
78 |
+
device (str, optional): Device to load the data on. Defaults to 'cpu'.
|
79 |
+
transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
|
80 |
+
"""
|
81 |
+
|
82 |
+
self.config = config
|
83 |
+
|
84 |
+
if config.dataset == 'ibims':
|
85 |
+
self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
|
86 |
+
return
|
87 |
+
|
88 |
+
if config.dataset == 'sunrgbd':
|
89 |
+
self.data = get_sunrgbd_loader(
|
90 |
+
data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
|
91 |
+
return
|
92 |
+
|
93 |
+
if config.dataset == 'diml_indoor':
|
94 |
+
self.data = get_diml_indoor_loader(
|
95 |
+
data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
|
96 |
+
return
|
97 |
+
|
98 |
+
if config.dataset == 'diml_outdoor':
|
99 |
+
self.data = get_diml_outdoor_loader(
|
100 |
+
data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
|
101 |
+
return
|
102 |
+
|
103 |
+
if "diode" in config.dataset:
|
104 |
+
self.data = get_diode_loader(
|
105 |
+
config[config.dataset+"_root"], batch_size=1, num_workers=1)
|
106 |
+
return
|
107 |
+
|
108 |
+
if config.dataset == 'hypersim_test':
|
109 |
+
self.data = get_hypersim_loader(
|
110 |
+
config.hypersim_test_root, batch_size=1, num_workers=1)
|
111 |
+
return
|
112 |
+
|
113 |
+
if config.dataset == 'vkitti':
|
114 |
+
self.data = get_vkitti_loader(
|
115 |
+
config.vkitti_root, batch_size=1, num_workers=1)
|
116 |
+
return
|
117 |
+
|
118 |
+
if config.dataset == 'vkitti2':
|
119 |
+
self.data = get_vkitti2_loader(
|
120 |
+
config.vkitti2_root, batch_size=1, num_workers=1)
|
121 |
+
return
|
122 |
+
|
123 |
+
if config.dataset == 'ddad':
|
124 |
+
self.data = get_ddad_loader(config.ddad_root, resize_shape=(
|
125 |
+
352, 1216), batch_size=1, num_workers=1)
|
126 |
+
return
|
127 |
+
|
128 |
+
img_size = self.config.get("img_size", None)
|
129 |
+
img_size = img_size if self.config.get(
|
130 |
+
"do_input_resize", False) else None
|
131 |
+
|
132 |
+
if transform is None:
|
133 |
+
transform = preprocessing_transforms(mode, size=img_size)
|
134 |
+
|
135 |
+
if mode == 'train':
|
136 |
+
|
137 |
+
Dataset = DataLoadPreprocess
|
138 |
+
self.training_samples = Dataset(
|
139 |
+
config, mode, transform=transform, device=device)
|
140 |
+
|
141 |
+
if config.distributed:
|
142 |
+
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
|
143 |
+
self.training_samples)
|
144 |
+
else:
|
145 |
+
self.train_sampler = None
|
146 |
+
|
147 |
+
self.data = DataLoader(self.training_samples,
|
148 |
+
batch_size=config.batch_size,
|
149 |
+
shuffle=(self.train_sampler is None),
|
150 |
+
num_workers=config.workers,
|
151 |
+
pin_memory=True,
|
152 |
+
persistent_workers=True,
|
153 |
+
# prefetch_factor=2,
|
154 |
+
sampler=self.train_sampler)
|
155 |
+
|
156 |
+
elif mode == 'online_eval':
|
157 |
+
self.testing_samples = DataLoadPreprocess(
|
158 |
+
config, mode, transform=transform)
|
159 |
+
if config.distributed: # redundant. here only for readability and to be more explicit
|
160 |
+
# Give whole test set to all processes (and report evaluation only on one) regardless
|
161 |
+
self.eval_sampler = None
|
162 |
+
else:
|
163 |
+
self.eval_sampler = None
|
164 |
+
self.data = DataLoader(self.testing_samples, 1,
|
165 |
+
shuffle=kwargs.get("shuffle_test", False),
|
166 |
+
num_workers=1,
|
167 |
+
pin_memory=False,
|
168 |
+
sampler=self.eval_sampler)
|
169 |
+
|
170 |
+
elif mode == 'test':
|
171 |
+
self.testing_samples = DataLoadPreprocess(
|
172 |
+
config, mode, transform=transform)
|
173 |
+
self.data = DataLoader(self.testing_samples,
|
174 |
+
1, shuffle=False, num_workers=1)
|
175 |
+
|
176 |
+
else:
|
177 |
+
print(
|
178 |
+
'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
|
179 |
+
|
180 |
+
|
181 |
+
def repetitive_roundrobin(*iterables):
|
182 |
+
"""
|
183 |
+
cycles through iterables but sample wise
|
184 |
+
first yield first sample from first iterable then first sample from second iterable and so on
|
185 |
+
then second sample from first iterable then second sample from second iterable and so on
|
186 |
+
|
187 |
+
If one iterable is shorter than the others, it is repeated until all iterables are exhausted
|
188 |
+
repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
|
189 |
+
"""
|
190 |
+
# Repetitive roundrobin
|
191 |
+
iterables_ = [iter(it) for it in iterables]
|
192 |
+
exhausted = [False] * len(iterables)
|
193 |
+
while not all(exhausted):
|
194 |
+
for i, it in enumerate(iterables_):
|
195 |
+
try:
|
196 |
+
yield next(it)
|
197 |
+
except StopIteration:
|
198 |
+
exhausted[i] = True
|
199 |
+
iterables_[i] = itertools.cycle(iterables[i])
|
200 |
+
# First elements may get repeated if one iterable is shorter than the others
|
201 |
+
yield next(iterables_[i])
|
202 |
+
|
203 |
+
|
204 |
+
class RepetitiveRoundRobinDataLoader(object):
|
205 |
+
def __init__(self, *dataloaders):
|
206 |
+
self.dataloaders = dataloaders
|
207 |
+
|
208 |
+
def __iter__(self):
|
209 |
+
return repetitive_roundrobin(*self.dataloaders)
|
210 |
+
|
211 |
+
def __len__(self):
|
212 |
+
# First samples get repeated, thats why the plus one
|
213 |
+
return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
|
214 |
+
|
215 |
+
|
216 |
+
class MixedNYUKITTI(object):
|
217 |
+
def __init__(self, config, mode, device='cpu', **kwargs):
|
218 |
+
config = edict(config)
|
219 |
+
config.workers = config.workers // 2
|
220 |
+
self.config = config
|
221 |
+
nyu_conf = change_dataset(edict(config), 'nyu')
|
222 |
+
kitti_conf = change_dataset(edict(config), 'kitti')
|
223 |
+
|
224 |
+
# make nyu default for testing
|
225 |
+
self.config = config = nyu_conf
|
226 |
+
img_size = self.config.get("img_size", None)
|
227 |
+
img_size = img_size if self.config.get(
|
228 |
+
"do_input_resize", False) else None
|
229 |
+
if mode == 'train':
|
230 |
+
nyu_loader = DepthDataLoader(
|
231 |
+
nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
|
232 |
+
kitti_loader = DepthDataLoader(
|
233 |
+
kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
|
234 |
+
# It has been changed to repetitive roundrobin
|
235 |
+
self.data = RepetitiveRoundRobinDataLoader(
|
236 |
+
nyu_loader, kitti_loader)
|
237 |
+
else:
|
238 |
+
self.data = DepthDataLoader(nyu_conf, mode, device=device).data
|
239 |
+
|
240 |
+
|
241 |
+
def remove_leading_slash(s):
|
242 |
+
if s[0] == '/' or s[0] == '\\':
|
243 |
+
return s[1:]
|
244 |
+
return s
|
245 |
+
|
246 |
+
|
247 |
+
class CachedReader:
|
248 |
+
def __init__(self, shared_dict=None):
|
249 |
+
if shared_dict:
|
250 |
+
self._cache = shared_dict
|
251 |
+
else:
|
252 |
+
self._cache = {}
|
253 |
+
|
254 |
+
def open(self, fpath):
|
255 |
+
im = self._cache.get(fpath, None)
|
256 |
+
if im is None:
|
257 |
+
im = self._cache[fpath] = Image.open(fpath)
|
258 |
+
return im
|
259 |
+
|
260 |
+
|
261 |
+
class ImReader:
|
262 |
+
def __init__(self):
|
263 |
+
pass
|
264 |
+
|
265 |
+
# @cache
|
266 |
+
def open(self, fpath):
|
267 |
+
return Image.open(fpath)
|
268 |
+
|
269 |
+
|
270 |
+
class DataLoadPreprocess(Dataset):
|
271 |
+
def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
|
272 |
+
self.config = config
|
273 |
+
if mode == 'online_eval':
|
274 |
+
with open(config.filenames_file_eval, 'r') as f:
|
275 |
+
self.filenames = f.readlines()
|
276 |
+
else:
|
277 |
+
with open(config.filenames_file, 'r') as f:
|
278 |
+
self.filenames = f.readlines()
|
279 |
+
|
280 |
+
self.mode = mode
|
281 |
+
self.transform = transform
|
282 |
+
self.to_tensor = ToTensor(mode)
|
283 |
+
self.is_for_online_eval = is_for_online_eval
|
284 |
+
if config.use_shared_dict:
|
285 |
+
self.reader = CachedReader(config.shared_dict)
|
286 |
+
else:
|
287 |
+
self.reader = ImReader()
|
288 |
+
|
289 |
+
def postprocess(self, sample):
|
290 |
+
return sample
|
291 |
+
|
292 |
+
def __getitem__(self, idx):
|
293 |
+
sample_path = self.filenames[idx]
|
294 |
+
focal = float(sample_path.split()[2])
|
295 |
+
sample = {}
|
296 |
+
|
297 |
+
if self.mode == 'train':
|
298 |
+
if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
|
299 |
+
image_path = os.path.join(
|
300 |
+
self.config.data_path, remove_leading_slash(sample_path.split()[3]))
|
301 |
+
depth_path = os.path.join(
|
302 |
+
self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
|
303 |
+
else:
|
304 |
+
image_path = os.path.join(
|
305 |
+
self.config.data_path, remove_leading_slash(sample_path.split()[0]))
|
306 |
+
depth_path = os.path.join(
|
307 |
+
self.config.gt_path, remove_leading_slash(sample_path.split()[1]))
|
308 |
+
|
309 |
+
image = self.reader.open(image_path)
|
310 |
+
depth_gt = self.reader.open(depth_path)
|
311 |
+
w, h = image.size
|
312 |
+
|
313 |
+
if self.config.do_kb_crop:
|
314 |
+
height = image.height
|
315 |
+
width = image.width
|
316 |
+
top_margin = int(height - 352)
|
317 |
+
left_margin = int((width - 1216) / 2)
|
318 |
+
depth_gt = depth_gt.crop(
|
319 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
320 |
+
image = image.crop(
|
321 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
322 |
+
|
323 |
+
# Avoid blank boundaries due to pixel registration?
|
324 |
+
# Train images have white border. Test images have black border.
|
325 |
+
if self.config.dataset == 'nyu' and self.config.avoid_boundary:
|
326 |
+
# print("Avoiding Blank Boundaries!")
|
327 |
+
# We just crop and pad again with reflect padding to original size
|
328 |
+
# original_size = image.size
|
329 |
+
crop_params = get_white_border(np.array(image, dtype=np.uint8))
|
330 |
+
image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
|
331 |
+
depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
|
332 |
+
|
333 |
+
# Use reflect padding to fill the blank
|
334 |
+
image = np.array(image)
|
335 |
+
image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
|
336 |
+
image = Image.fromarray(image)
|
337 |
+
|
338 |
+
depth_gt = np.array(depth_gt)
|
339 |
+
depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
|
340 |
+
depth_gt = Image.fromarray(depth_gt)
|
341 |
+
|
342 |
+
|
343 |
+
if self.config.do_random_rotate and (self.config.aug):
|
344 |
+
random_angle = (random.random() - 0.5) * 2 * self.config.degree
|
345 |
+
image = self.rotate_image(image, random_angle)
|
346 |
+
depth_gt = self.rotate_image(
|
347 |
+
depth_gt, random_angle, flag=Image.NEAREST)
|
348 |
+
|
349 |
+
image = np.asarray(image, dtype=np.float32) / 255.0
|
350 |
+
depth_gt = np.asarray(depth_gt, dtype=np.float32)
|
351 |
+
depth_gt = np.expand_dims(depth_gt, axis=2)
|
352 |
+
|
353 |
+
if self.config.dataset == 'nyu':
|
354 |
+
depth_gt = depth_gt / 1000.0
|
355 |
+
else:
|
356 |
+
depth_gt = depth_gt / 256.0
|
357 |
+
|
358 |
+
if self.config.aug and (self.config.random_crop):
|
359 |
+
image, depth_gt = self.random_crop(
|
360 |
+
image, depth_gt, self.config.input_height, self.config.input_width)
|
361 |
+
|
362 |
+
if self.config.aug and self.config.random_translate:
|
363 |
+
# print("Random Translation!")
|
364 |
+
image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)
|
365 |
+
|
366 |
+
image, depth_gt = self.train_preprocess(image, depth_gt)
|
367 |
+
mask = np.logical_and(depth_gt > self.config.min_depth,
|
368 |
+
depth_gt < self.config.max_depth).squeeze()[None, ...]
|
369 |
+
sample = {'image': image, 'depth': depth_gt, 'focal': focal,
|
370 |
+
'mask': mask, **sample}
|
371 |
+
|
372 |
+
else:
|
373 |
+
if self.mode == 'online_eval':
|
374 |
+
data_path = self.config.data_path_eval
|
375 |
+
else:
|
376 |
+
data_path = self.config.data_path
|
377 |
+
|
378 |
+
image_path = os.path.join(
|
379 |
+
data_path, remove_leading_slash(sample_path.split()[0]))
|
380 |
+
image = np.asarray(self.reader.open(image_path),
|
381 |
+
dtype=np.float32) / 255.0
|
382 |
+
|
383 |
+
if self.mode == 'online_eval':
|
384 |
+
gt_path = self.config.gt_path_eval
|
385 |
+
depth_path = os.path.join(
|
386 |
+
gt_path, remove_leading_slash(sample_path.split()[1]))
|
387 |
+
has_valid_depth = False
|
388 |
+
try:
|
389 |
+
depth_gt = self.reader.open(depth_path)
|
390 |
+
has_valid_depth = True
|
391 |
+
except IOError:
|
392 |
+
depth_gt = False
|
393 |
+
# print('Missing gt for {}'.format(image_path))
|
394 |
+
|
395 |
+
if has_valid_depth:
|
396 |
+
depth_gt = np.asarray(depth_gt, dtype=np.float32)
|
397 |
+
depth_gt = np.expand_dims(depth_gt, axis=2)
|
398 |
+
if self.config.dataset == 'nyu':
|
399 |
+
depth_gt = depth_gt / 1000.0
|
400 |
+
else:
|
401 |
+
depth_gt = depth_gt / 256.0
|
402 |
+
|
403 |
+
mask = np.logical_and(
|
404 |
+
depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
|
405 |
+
else:
|
406 |
+
mask = False
|
407 |
+
|
408 |
+
if self.config.do_kb_crop:
|
409 |
+
height = image.shape[0]
|
410 |
+
width = image.shape[1]
|
411 |
+
top_margin = int(height - 352)
|
412 |
+
left_margin = int((width - 1216) / 2)
|
413 |
+
image = image[top_margin:top_margin + 352,
|
414 |
+
left_margin:left_margin + 1216, :]
|
415 |
+
if self.mode == 'online_eval' and has_valid_depth:
|
416 |
+
depth_gt = depth_gt[top_margin:top_margin +
|
417 |
+
352, left_margin:left_margin + 1216, :]
|
418 |
+
|
419 |
+
if self.mode == 'online_eval':
|
420 |
+
sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
|
421 |
+
'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
|
422 |
+
'mask': mask}
|
423 |
+
else:
|
424 |
+
sample = {'image': image, 'focal': focal}
|
425 |
+
|
426 |
+
if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
|
427 |
+
mask = np.logical_and(depth_gt > self.config.min_depth,
|
428 |
+
depth_gt < self.config.max_depth).squeeze()[None, ...]
|
429 |
+
sample['mask'] = mask
|
430 |
+
|
431 |
+
if self.transform:
|
432 |
+
sample = self.transform(sample)
|
433 |
+
|
434 |
+
sample = self.postprocess(sample)
|
435 |
+
sample['dataset'] = self.config.dataset
|
436 |
+
sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}
|
437 |
+
|
438 |
+
return sample
|
439 |
+
|
440 |
+
def rotate_image(self, image, angle, flag=Image.BILINEAR):
|
441 |
+
result = image.rotate(angle, resample=flag)
|
442 |
+
return result
|
443 |
+
|
444 |
+
def random_crop(self, img, depth, height, width):
|
445 |
+
assert img.shape[0] >= height
|
446 |
+
assert img.shape[1] >= width
|
447 |
+
assert img.shape[0] == depth.shape[0]
|
448 |
+
assert img.shape[1] == depth.shape[1]
|
449 |
+
x = random.randint(0, img.shape[1] - width)
|
450 |
+
y = random.randint(0, img.shape[0] - height)
|
451 |
+
img = img[y:y + height, x:x + width, :]
|
452 |
+
depth = depth[y:y + height, x:x + width, :]
|
453 |
+
|
454 |
+
return img, depth
|
455 |
+
|
456 |
+
def random_translate(self, img, depth, max_t=20):
|
457 |
+
assert img.shape[0] == depth.shape[0]
|
458 |
+
assert img.shape[1] == depth.shape[1]
|
459 |
+
p = self.config.translate_prob
|
460 |
+
do_translate = random.random()
|
461 |
+
if do_translate > p:
|
462 |
+
return img, depth
|
463 |
+
x = random.randint(-max_t, max_t)
|
464 |
+
y = random.randint(-max_t, max_t)
|
465 |
+
M = np.float32([[1, 0, x], [0, 1, y]])
|
466 |
+
# print(img.shape, depth.shape)
|
467 |
+
img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
|
468 |
+
depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
|
469 |
+
depth = depth.squeeze()[..., None] # add channel dim back. Affine warp removes it
|
470 |
+
# print("after", img.shape, depth.shape)
|
471 |
+
return img, depth
|
472 |
+
|
473 |
+
def train_preprocess(self, image, depth_gt):
|
474 |
+
if self.config.aug:
|
475 |
+
# Random flipping
|
476 |
+
do_flip = random.random()
|
477 |
+
if do_flip > 0.5:
|
478 |
+
image = (image[:, ::-1, :]).copy()
|
479 |
+
depth_gt = (depth_gt[:, ::-1, :]).copy()
|
480 |
+
|
481 |
+
# Random gamma, brightness, color augmentation
|
482 |
+
do_augment = random.random()
|
483 |
+
if do_augment > 0.5:
|
484 |
+
image = self.augment_image(image)
|
485 |
+
|
486 |
+
return image, depth_gt
|
487 |
+
|
488 |
+
def augment_image(self, image):
|
489 |
+
# gamma augmentation
|
490 |
+
gamma = random.uniform(0.9, 1.1)
|
491 |
+
image_aug = image ** gamma
|
492 |
+
|
493 |
+
# brightness augmentation
|
494 |
+
if self.config.dataset == 'nyu':
|
495 |
+
brightness = random.uniform(0.75, 1.25)
|
496 |
+
else:
|
497 |
+
brightness = random.uniform(0.9, 1.1)
|
498 |
+
image_aug = image_aug * brightness
|
499 |
+
|
500 |
+
# color augmentation
|
501 |
+
colors = np.random.uniform(0.9, 1.1, size=3)
|
502 |
+
white = np.ones((image.shape[0], image.shape[1]))
|
503 |
+
color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
|
504 |
+
image_aug *= color_image
|
505 |
+
image_aug = np.clip(image_aug, 0, 1)
|
506 |
+
|
507 |
+
return image_aug
|
508 |
+
|
509 |
+
def __len__(self):
|
510 |
+
return len(self.filenames)
|
511 |
+
|
512 |
+
|
513 |
+
class ToTensor(object):
|
514 |
+
def __init__(self, mode, do_normalize=False, size=None):
|
515 |
+
self.mode = mode
|
516 |
+
self.normalize = transforms.Normalize(
|
517 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
|
518 |
+
self.size = size
|
519 |
+
if size is not None:
|
520 |
+
self.resize = transforms.Resize(size=size)
|
521 |
+
else:
|
522 |
+
self.resize = nn.Identity()
|
523 |
+
|
524 |
+
def __call__(self, sample):
|
525 |
+
image, focal = sample['image'], sample['focal']
|
526 |
+
image = self.to_tensor(image)
|
527 |
+
image = self.normalize(image)
|
528 |
+
image = self.resize(image)
|
529 |
+
|
530 |
+
if self.mode == 'test':
|
531 |
+
return {'image': image, 'focal': focal}
|
532 |
+
|
533 |
+
depth = sample['depth']
|
534 |
+
if self.mode == 'train':
|
535 |
+
depth = self.to_tensor(depth)
|
536 |
+
return {**sample, 'image': image, 'depth': depth, 'focal': focal}
|
537 |
+
else:
|
538 |
+
has_valid_depth = sample['has_valid_depth']
|
539 |
+
image = self.resize(image)
|
540 |
+
return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
|
541 |
+
'image_path': sample['image_path'], 'depth_path': sample['depth_path']}
|
542 |
+
|
543 |
+
def to_tensor(self, pic):
|
544 |
+
if not (_is_pil_image(pic) or _is_numpy_image(pic)):
|
545 |
+
raise TypeError(
|
546 |
+
'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
|
547 |
+
|
548 |
+
if isinstance(pic, np.ndarray):
|
549 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
550 |
+
return img
|
551 |
+
|
552 |
+
# handle PIL Image
|
553 |
+
if pic.mode == 'I':
|
554 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
555 |
+
elif pic.mode == 'I;16':
|
556 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
557 |
+
else:
|
558 |
+
img = torch.ByteTensor(
|
559 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
560 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
561 |
+
if pic.mode == 'YCbCr':
|
562 |
+
nchannel = 3
|
563 |
+
elif pic.mode == 'I;16':
|
564 |
+
nchannel = 1
|
565 |
+
else:
|
566 |
+
nchannel = len(pic.mode)
|
567 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
568 |
+
|
569 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
570 |
+
if isinstance(img, torch.ByteTensor):
|
571 |
+
return img.float()
|
572 |
+
else:
|
573 |
+
return img
|
src/flux/annotator/zoe/zoedepth/data/ddad.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from PIL import Image
|
30 |
+
from torch.utils.data import DataLoader, Dataset
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
|
34 |
+
class ToTensor(object):
|
35 |
+
def __init__(self, resize_shape):
|
36 |
+
# self.normalize = transforms.Normalize(
|
37 |
+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
38 |
+
self.normalize = lambda x : x
|
39 |
+
self.resize = transforms.Resize(resize_shape)
|
40 |
+
|
41 |
+
def __call__(self, sample):
|
42 |
+
image, depth = sample['image'], sample['depth']
|
43 |
+
image = self.to_tensor(image)
|
44 |
+
image = self.normalize(image)
|
45 |
+
depth = self.to_tensor(depth)
|
46 |
+
|
47 |
+
image = self.resize(image)
|
48 |
+
|
49 |
+
return {'image': image, 'depth': depth, 'dataset': "ddad"}
|
50 |
+
|
51 |
+
def to_tensor(self, pic):
|
52 |
+
|
53 |
+
if isinstance(pic, np.ndarray):
|
54 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
55 |
+
return img
|
56 |
+
|
57 |
+
# # handle PIL Image
|
58 |
+
if pic.mode == 'I':
|
59 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
60 |
+
elif pic.mode == 'I;16':
|
61 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
62 |
+
else:
|
63 |
+
img = torch.ByteTensor(
|
64 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
65 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
66 |
+
if pic.mode == 'YCbCr':
|
67 |
+
nchannel = 3
|
68 |
+
elif pic.mode == 'I;16':
|
69 |
+
nchannel = 1
|
70 |
+
else:
|
71 |
+
nchannel = len(pic.mode)
|
72 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
73 |
+
|
74 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
75 |
+
|
76 |
+
if isinstance(img, torch.ByteTensor):
|
77 |
+
return img.float()
|
78 |
+
else:
|
79 |
+
return img
|
80 |
+
|
81 |
+
|
82 |
+
class DDAD(Dataset):
|
83 |
+
def __init__(self, data_dir_root, resize_shape):
|
84 |
+
import glob
|
85 |
+
|
86 |
+
# image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
|
87 |
+
self.image_files = glob.glob(os.path.join(data_dir_root, '*.png'))
|
88 |
+
self.depth_files = [r.replace("_rgb.png", "_depth.npy")
|
89 |
+
for r in self.image_files]
|
90 |
+
self.transform = ToTensor(resize_shape)
|
91 |
+
|
92 |
+
def __getitem__(self, idx):
|
93 |
+
|
94 |
+
image_path = self.image_files[idx]
|
95 |
+
depth_path = self.depth_files[idx]
|
96 |
+
|
97 |
+
image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
|
98 |
+
depth = np.load(depth_path) # meters
|
99 |
+
|
100 |
+
# depth[depth > 8] = -1
|
101 |
+
depth = depth[..., None]
|
102 |
+
|
103 |
+
sample = dict(image=image, depth=depth)
|
104 |
+
sample = self.transform(sample)
|
105 |
+
|
106 |
+
if idx == 0:
|
107 |
+
print(sample["image"].shape)
|
108 |
+
|
109 |
+
return sample
|
110 |
+
|
111 |
+
def __len__(self):
|
112 |
+
return len(self.image_files)
|
113 |
+
|
114 |
+
|
115 |
+
def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs):
|
116 |
+
dataset = DDAD(data_dir_root, resize_shape)
|
117 |
+
return DataLoader(dataset, batch_size, **kwargs)
|
src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from PIL import Image
|
30 |
+
from torch.utils.data import DataLoader, Dataset
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
|
34 |
+
class ToTensor(object):
|
35 |
+
def __init__(self):
|
36 |
+
# self.normalize = transforms.Normalize(
|
37 |
+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
38 |
+
self.normalize = lambda x : x
|
39 |
+
self.resize = transforms.Resize((480, 640))
|
40 |
+
|
41 |
+
def __call__(self, sample):
|
42 |
+
image, depth = sample['image'], sample['depth']
|
43 |
+
image = self.to_tensor(image)
|
44 |
+
image = self.normalize(image)
|
45 |
+
depth = self.to_tensor(depth)
|
46 |
+
|
47 |
+
image = self.resize(image)
|
48 |
+
|
49 |
+
return {'image': image, 'depth': depth, 'dataset': "diml_indoor"}
|
50 |
+
|
51 |
+
def to_tensor(self, pic):
|
52 |
+
|
53 |
+
if isinstance(pic, np.ndarray):
|
54 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
55 |
+
return img
|
56 |
+
|
57 |
+
# # handle PIL Image
|
58 |
+
if pic.mode == 'I':
|
59 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
60 |
+
elif pic.mode == 'I;16':
|
61 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
62 |
+
else:
|
63 |
+
img = torch.ByteTensor(
|
64 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
65 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
66 |
+
if pic.mode == 'YCbCr':
|
67 |
+
nchannel = 3
|
68 |
+
elif pic.mode == 'I;16':
|
69 |
+
nchannel = 1
|
70 |
+
else:
|
71 |
+
nchannel = len(pic.mode)
|
72 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
73 |
+
|
74 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
75 |
+
if isinstance(img, torch.ByteTensor):
|
76 |
+
return img.float()
|
77 |
+
else:
|
78 |
+
return img
|
79 |
+
|
80 |
+
|
81 |
+
class DIML_Indoor(Dataset):
|
82 |
+
def __init__(self, data_dir_root):
|
83 |
+
import glob
|
84 |
+
|
85 |
+
# image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
|
86 |
+
self.image_files = glob.glob(os.path.join(
|
87 |
+
data_dir_root, "LR", '*', 'color', '*.png'))
|
88 |
+
self.depth_files = [r.replace("color", "depth_filled").replace(
|
89 |
+
"_c.png", "_depth_filled.png") for r in self.image_files]
|
90 |
+
self.transform = ToTensor()
|
91 |
+
|
92 |
+
def __getitem__(self, idx):
|
93 |
+
image_path = self.image_files[idx]
|
94 |
+
depth_path = self.depth_files[idx]
|
95 |
+
|
96 |
+
image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
|
97 |
+
depth = np.asarray(Image.open(depth_path),
|
98 |
+
dtype='uint16') / 1000.0 # mm to meters
|
99 |
+
|
100 |
+
# print(np.shape(image))
|
101 |
+
# print(np.shape(depth))
|
102 |
+
|
103 |
+
# depth[depth > 8] = -1
|
104 |
+
depth = depth[..., None]
|
105 |
+
|
106 |
+
sample = dict(image=image, depth=depth)
|
107 |
+
|
108 |
+
# return sample
|
109 |
+
sample = self.transform(sample)
|
110 |
+
|
111 |
+
if idx == 0:
|
112 |
+
print(sample["image"].shape)
|
113 |
+
|
114 |
+
return sample
|
115 |
+
|
116 |
+
def __len__(self):
|
117 |
+
return len(self.image_files)
|
118 |
+
|
119 |
+
|
120 |
+
def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs):
|
121 |
+
dataset = DIML_Indoor(data_dir_root)
|
122 |
+
return DataLoader(dataset, batch_size, **kwargs)
|
123 |
+
|
124 |
+
# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR")
|
125 |
+
# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR")
|
src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from PIL import Image
|
30 |
+
from torch.utils.data import DataLoader, Dataset
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
|
34 |
+
class ToTensor(object):
|
35 |
+
def __init__(self):
|
36 |
+
# self.normalize = transforms.Normalize(
|
37 |
+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
38 |
+
self.normalize = lambda x : x
|
39 |
+
|
40 |
+
def __call__(self, sample):
|
41 |
+
image, depth = sample['image'], sample['depth']
|
42 |
+
image = self.to_tensor(image)
|
43 |
+
image = self.normalize(image)
|
44 |
+
depth = self.to_tensor(depth)
|
45 |
+
|
46 |
+
return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"}
|
47 |
+
|
48 |
+
def to_tensor(self, pic):
|
49 |
+
|
50 |
+
if isinstance(pic, np.ndarray):
|
51 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
52 |
+
return img
|
53 |
+
|
54 |
+
# # handle PIL Image
|
55 |
+
if pic.mode == 'I':
|
56 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
57 |
+
elif pic.mode == 'I;16':
|
58 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
59 |
+
else:
|
60 |
+
img = torch.ByteTensor(
|
61 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
62 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
63 |
+
if pic.mode == 'YCbCr':
|
64 |
+
nchannel = 3
|
65 |
+
elif pic.mode == 'I;16':
|
66 |
+
nchannel = 1
|
67 |
+
else:
|
68 |
+
nchannel = len(pic.mode)
|
69 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
70 |
+
|
71 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
72 |
+
if isinstance(img, torch.ByteTensor):
|
73 |
+
return img.float()
|
74 |
+
else:
|
75 |
+
return img
|
76 |
+
|
77 |
+
|
78 |
+
class DIML_Outdoor(Dataset):
|
79 |
+
def __init__(self, data_dir_root):
|
80 |
+
import glob
|
81 |
+
|
82 |
+
# image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
|
83 |
+
self.image_files = glob.glob(os.path.join(
|
84 |
+
data_dir_root, "*", 'outleft', '*.png'))
|
85 |
+
self.depth_files = [r.replace("outleft", "depthmap")
|
86 |
+
for r in self.image_files]
|
87 |
+
self.transform = ToTensor()
|
88 |
+
|
89 |
+
def __getitem__(self, idx):
|
90 |
+
image_path = self.image_files[idx]
|
91 |
+
depth_path = self.depth_files[idx]
|
92 |
+
|
93 |
+
image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
|
94 |
+
depth = np.asarray(Image.open(depth_path),
|
95 |
+
dtype='uint16') / 1000.0 # mm to meters
|
96 |
+
|
97 |
+
# depth[depth > 8] = -1
|
98 |
+
depth = depth[..., None]
|
99 |
+
|
100 |
+
sample = dict(image=image, depth=depth, dataset="diml_outdoor")
|
101 |
+
|
102 |
+
# return sample
|
103 |
+
return self.transform(sample)
|
104 |
+
|
105 |
+
def __len__(self):
|
106 |
+
return len(self.image_files)
|
107 |
+
|
108 |
+
|
109 |
+
def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs):
|
110 |
+
dataset = DIML_Outdoor(data_dir_root)
|
111 |
+
return DataLoader(dataset, batch_size, **kwargs)
|
112 |
+
|
113 |
+
# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR")
|
114 |
+
# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR")
|
src/flux/annotator/zoe/zoedepth/data/diode.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from PIL import Image
|
30 |
+
from torch.utils.data import DataLoader, Dataset
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
|
34 |
+
class ToTensor(object):
|
35 |
+
def __init__(self):
|
36 |
+
# self.normalize = transforms.Normalize(
|
37 |
+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
38 |
+
self.normalize = lambda x : x
|
39 |
+
self.resize = transforms.Resize(480)
|
40 |
+
|
41 |
+
def __call__(self, sample):
|
42 |
+
image, depth = sample['image'], sample['depth']
|
43 |
+
image = self.to_tensor(image)
|
44 |
+
image = self.normalize(image)
|
45 |
+
depth = self.to_tensor(depth)
|
46 |
+
|
47 |
+
image = self.resize(image)
|
48 |
+
|
49 |
+
return {'image': image, 'depth': depth, 'dataset': "diode"}
|
50 |
+
|
51 |
+
def to_tensor(self, pic):
|
52 |
+
|
53 |
+
if isinstance(pic, np.ndarray):
|
54 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
55 |
+
return img
|
56 |
+
|
57 |
+
# # handle PIL Image
|
58 |
+
if pic.mode == 'I':
|
59 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
60 |
+
elif pic.mode == 'I;16':
|
61 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
62 |
+
else:
|
63 |
+
img = torch.ByteTensor(
|
64 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
65 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
66 |
+
if pic.mode == 'YCbCr':
|
67 |
+
nchannel = 3
|
68 |
+
elif pic.mode == 'I;16':
|
69 |
+
nchannel = 1
|
70 |
+
else:
|
71 |
+
nchannel = len(pic.mode)
|
72 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
73 |
+
|
74 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
75 |
+
|
76 |
+
if isinstance(img, torch.ByteTensor):
|
77 |
+
return img.float()
|
78 |
+
else:
|
79 |
+
return img
|
80 |
+
|
81 |
+
|
82 |
+
class DIODE(Dataset):
|
83 |
+
def __init__(self, data_dir_root):
|
84 |
+
import glob
|
85 |
+
|
86 |
+
# image paths are of the form <data_dir_root>/scene_#/scan_#/*.png
|
87 |
+
self.image_files = glob.glob(
|
88 |
+
os.path.join(data_dir_root, '*', '*', '*.png'))
|
89 |
+
self.depth_files = [r.replace(".png", "_depth.npy")
|
90 |
+
for r in self.image_files]
|
91 |
+
self.depth_mask_files = [
|
92 |
+
r.replace(".png", "_depth_mask.npy") for r in self.image_files]
|
93 |
+
self.transform = ToTensor()
|
94 |
+
|
95 |
+
def __getitem__(self, idx):
|
96 |
+
image_path = self.image_files[idx]
|
97 |
+
depth_path = self.depth_files[idx]
|
98 |
+
depth_mask_path = self.depth_mask_files[idx]
|
99 |
+
|
100 |
+
image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
|
101 |
+
depth = np.load(depth_path) # in meters
|
102 |
+
valid = np.load(depth_mask_path) # binary
|
103 |
+
|
104 |
+
# depth[depth > 8] = -1
|
105 |
+
# depth = depth[..., None]
|
106 |
+
|
107 |
+
sample = dict(image=image, depth=depth, valid=valid)
|
108 |
+
|
109 |
+
# return sample
|
110 |
+
sample = self.transform(sample)
|
111 |
+
|
112 |
+
if idx == 0:
|
113 |
+
print(sample["image"].shape)
|
114 |
+
|
115 |
+
return sample
|
116 |
+
|
117 |
+
def __len__(self):
|
118 |
+
return len(self.image_files)
|
119 |
+
|
120 |
+
|
121 |
+
def get_diode_loader(data_dir_root, batch_size=1, **kwargs):
|
122 |
+
dataset = DIODE(data_dir_root)
|
123 |
+
return DataLoader(dataset, batch_size, **kwargs)
|
124 |
+
|
125 |
+
# get_diode_loader(data_dir_root="datasets/diode/val/outdoor")
|