bol committed
Commit 99738e0 · 1 Parent(s): e493783
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +1 -0
  3. .gitignore +29 -0
  4. README.md +8 -6
  5. app.py +163 -144
  6. app_old.py +374 -0
  7. image_datasets/.DS_Store +0 -0
  8. image_datasets/dataset.py +231 -0
  9. inference_configs/inference.yaml +33 -0
  10. requirements.txt +25 -5
  11. src/.DS_Store +0 -0
  12. src/flux/.DS_Store +0 -0
  13. src/flux/__init__.py +11 -0
  14. src/flux/__main__.py +4 -0
  15. src/flux/annotator/canny/__init__.py +6 -0
  16. src/flux/annotator/ckpts/ckpts.txt +1 -0
  17. src/flux/annotator/dwpose/__init__.py +68 -0
  18. src/flux/annotator/dwpose/onnxdet.py +125 -0
  19. src/flux/annotator/dwpose/onnxpose.py +360 -0
  20. src/flux/annotator/dwpose/util.py +297 -0
  21. src/flux/annotator/dwpose/wholebody.py +48 -0
  22. src/flux/annotator/hed/__init__.py +95 -0
  23. src/flux/annotator/midas/LICENSE +21 -0
  24. src/flux/annotator/midas/__init__.py +42 -0
  25. src/flux/annotator/midas/api.py +168 -0
  26. src/flux/annotator/midas/midas/__init__.py +0 -0
  27. src/flux/annotator/midas/midas/base_model.py +16 -0
  28. src/flux/annotator/midas/midas/blocks.py +342 -0
  29. src/flux/annotator/midas/midas/dpt_depth.py +109 -0
  30. src/flux/annotator/midas/midas/midas_net.py +76 -0
  31. src/flux/annotator/midas/midas/midas_net_custom.py +128 -0
  32. src/flux/annotator/midas/midas/transforms.py +234 -0
  33. src/flux/annotator/midas/midas/vit.py +491 -0
  34. src/flux/annotator/midas/utils.py +189 -0
  35. src/flux/annotator/mlsd/LICENSE +201 -0
  36. src/flux/annotator/mlsd/__init__.py +40 -0
  37. src/flux/annotator/mlsd/models/mbv2_mlsd_large.py +292 -0
  38. src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py +275 -0
  39. src/flux/annotator/mlsd/utils.py +580 -0
  40. src/flux/annotator/tile/__init__.py +26 -0
  41. src/flux/annotator/tile/guided_filter.py +280 -0
  42. src/flux/annotator/util.py +38 -0
  43. src/flux/annotator/zoe/LICENSE +21 -0
  44. src/flux/annotator/zoe/__init__.py +48 -0
  45. src/flux/annotator/zoe/zoedepth/data/__init__.py +24 -0
  46. src/flux/annotator/zoe/zoedepth/data/data_mono.py +573 -0
  47. src/flux/annotator/zoe/zoedepth/data/ddad.py +117 -0
  48. src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py +125 -0
  49. src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py +114 -0
  50. src/flux/annotator/zoe/zoedepth/data/diode.py +125 -0
.DS_Store ADDED
Binary file (8.2 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/**/* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,29 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ assets/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
README.md CHANGED
@@ -1,12 +1,14 @@
  ---
- title: Teset Demo
- emoji: 🖼
- colorFrom: purple
- colorTo: red
+ title: ByteMorph Demo
+ emoji: 📊
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
- sdk_version: 5.25.2
+ sdk_version: 5.31.0
  app_file: app.py
  pinned: false
+ license: other
+ short_description: Online Demo for ByteMorph
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,154 +1,173 @@
  import gradio as gr
- import numpy as np
- import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
+ import spaces
+ import os
+ import numpy as np
+ from PIL import Image
+
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ from omegaconf import OmegaConf
+ from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
+ from src.flux.xflux_pipeline import XFluxSampler
+ from image_datasets.dataset import image_resize
+
+ # ===== No CUDA/model initialization globally =====
+ args = OmegaConf.load("inference_configs/inference.yaml")
+ is_schnell = args.model_name == "flux-schnell"
+
+ # sampler = None
+ device = torch.device("cuda")
+ dtype = torch.bfloat16
+ dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
+ vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
+ t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
+ clip = load_clip("cpu").to(device, dtype=dtype)
+
+ vae.requires_grad_(False)
+ t5.requires_grad_(False)
+ clip.requires_grad_(False)
+
+ model_path = hf_hub_download(
+     repo_id="Boese0601/ByteMorpher",
+     filename="dit.safetensors",
+     use_auth_token=os.getenv("HF_TOKEN")
+ )
+ state_dict = load_file(model_path)
+ dit.load_state_dict(state_dict)
+ dit.eval()
+ dit.to(device, dtype=dtype)
+
+ sampler = XFluxSampler(
+     clip=clip,
+     t5=t5,
+     ae=vae,
+     model=dit,
+     device=device,
+     ip_loaded=False,
+     spatial_condition=False,
+     clip_image_processor=None,
+     image_encoder=None,
+     improj=None
+ )
+ #test push
+ @spaces.GPU
+ def generate(image: Image.Image, edit_prompt: str):
+     # global sampler
+     # device = torch.device("cuda")
+     # dtype = torch.bfloat16
+
+     # if sampler is None:
+     #     dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
+     #     vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
+     #     t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
+     #     clip = load_clip("cpu").to(device, dtype=dtype)
+
+     #     vae.requires_grad_(False)
+     #     t5.requires_grad_(False)
+     #     clip.requires_grad_(False)
+
+     #     model_path = hf_hub_download(
+     #         repo_id="Boese0601/ByteMorpher",
+     #         filename="dit.safetensors",
+     #         use_auth_token=os.getenv("HF_TOKEN")
+     #     )
+     #     state_dict = load_file(model_path)
+     #     dit.load_state_dict(state_dict)
+     #     dit.eval()
+
+     #     sampler = XFluxSampler(
+     #         clip=clip,
+     #         t5=t5,
+     #         ae=vae,
+     #         model=dit,
+     #         device=device,
+     #         ip_loaded=False,
+     #         spatial_condition=False,
+     #         clip_image_processor=None,
+     #         image_encoder=None,
+     #         improj=None
+     #     )
+
+     img = image_resize(image, 512)
+     w, h = img.size
+     img = img.resize(((w // 32) * 32, (h // 32) * 32))
+     img = torch.from_numpy((np.array(img) / 127.5) - 1)
+     img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
+
+     with torch.no_grad():
+         result = sampler(
+             prompt=edit_prompt,
+             width=args.sample_width,
+             height=args.sample_height,
+             num_steps=args.sample_steps,
+             image_prompt=None,
+             true_gs=args.cfg_scale,
+             seed=args.seed,
+             ip_scale=args.ip_scale if args.use_ip else 1.0,
+             source_image=img if args.use_spatial_condition else None,
+         )
+     return tensor_to_pil_image(result)
+
+ def get_samples():
+     sample_list = [
+         {
+             "image": "assets/0_camera_zoom/20486354.png",
+             "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
+         },
+     ]
+     return [
+         [
+             Image.open(sample["image"]).resize((512, 512)),
+             sample["edit_prompt"],
+         ]
+         for sample in sample_list
+     ]
+
+ header = """
+ # ByteMorph
+
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+ <a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
+ <a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
+ <a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
+ </div>
  """

- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
+ def create_app():
+     with gr.Blocks() as app:
+         gr.Markdown(header, elem_id="header")
+         with gr.Row(equal_height=False):
+             with gr.Column(variant="panel", elem_classes="inputPanel"):
+                 original_image = gr.Image(
+                     type="pil", label="Condition Image", width=300, elem_id="input"
                  )
+                 edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
+                 submit_btn = gr.Button("Run", elem_id="submit_btn")

-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
-                 )
+             with gr.Column(variant="panel", elem_classes="outputPanel"):
+                 output_image = gr.Image(type="pil", elem_id="output")

-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
-                 )
+         with gr.Row():
+             examples = gr.Examples(
+                 examples=get_samples(),
+                 inputs=[original_image, edit_prompt],
+                 label="Examples",
+             )

-     gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )
+         submit_btn.click(
+             fn=generate,
+             inputs=[original_image, edit_prompt],
+             outputs=output_image,
+         )
+         gr.HTML(
+             """
+             <div style="text-align: center;">
+             * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
+             </div>
+             """
+         )
+     return app

  if __name__ == "__main__":
-     demo.launch()
+     create_app().launch(debug=False, share=False, ssr_mode=False)
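
The new `generate` path resizes the condition image so both sides are multiples of 32 and maps pixels from [0, 255] into [-1, 1] before handing it to the sampler. A minimal sketch of just that preprocessing step, assuming only PIL/NumPy/torch (the `preprocess` name is illustrative; in app.py this logic is inlined and `image_resize` lives in image_datasets/dataset.py):

```python
import numpy as np
import torch
from PIL import Image

def preprocess(image: Image.Image, max_size: int = 512) -> torch.Tensor:
    # Scale the longer side to max_size, keeping aspect ratio
    # (same rule as image_resize in image_datasets/dataset.py).
    w, h = image.size
    scale = max_size / max(w, h)
    image = image.resize((int(w * scale), int(h * scale)))
    # Snap both sides down to multiples of 32, as the sampler expects.
    w, h = image.size
    image = image.resize(((w // 32) * 32, (h // 32) * 32))
    # uint8 [0, 255] -> float [-1, 1], HWC -> 1xCxHxW.
    x = torch.from_numpy(np.array(image) / 127.5 - 1.0)
    return x.permute(2, 0, 1).unsqueeze(0)

# A 640x480 input becomes a (1, 3, 384, 512) tensor in [-1, 1].
print(preprocess(Image.new("RGB", (640, 480))).shape)
```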
app_old.py ADDED
@@ -0,0 +1,374 @@
+ import gradio as gr
+ import torch
+ import spaces
+ from PIL import Image, ImageDraw, ImageFont
+ # from src.condition import Condition
+ from diffusers.pipelines import FluxPipeline
+ import numpy as np
+ import requests
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import torch.multiprocessing as mp
+ ###
+ import argparse
+ import logging
+ import math
+ import os
+ import re
+ import random
+ import shutil
+ from contextlib import nullcontext
+ from pathlib import Path
+ from PIL import Image
+ import accelerate
+ import datasets
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+ import torch.utils.checkpoint
+ import transformers
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.state import AcceleratorState
+ from accelerate.utils import ProjectConfiguration, set_seed
+ from huggingface_hub import create_repo, upload_folder
+ from packaging import version
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from transformers.utils import ContextManagers
+ from omegaconf import OmegaConf
+ from copy import deepcopy
+ import diffusers
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
+ from diffusers.optimization import get_scheduler
+ from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
+ from diffusers.utils import check_min_version, deprecate, make_image_grid
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+ from diffusers.utils.import_utils import is_xformers_available
+ from diffusers.utils.torch_utils import is_compiled_module
+ from einops import rearrange
+ from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
+ from src.flux.util import (configs, load_ae, load_clip,
+                            load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint)
+ from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor, IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel
+ from src.flux.xflux_pipeline import XFluxSampler
+
+ from image_datasets.dataset import loader, eval_image_pair_loader, image_resize
+
+ from safetensors.torch import load_file
+ import json
+
+
+ # logger = get_logger(__name__, log_level="INFO")
+
+
+ def get_models(name: str, device, offload: bool, is_schnell: bool):
+     t5 = load_t5(device, max_length=256 if is_schnell else 512)
+     clip = load_clip(device)
+     clip.requires_grad_(False)
+     model = load_flow_model2(name, device="cpu")
+     vae = load_ae(name, device="cpu" if offload else device)
+     return model, vae, t5, clip
+
+ args = OmegaConf.load("inference_configs/inference.yaml") #OmegaConf.load(parse_args())
+ is_schnell = args.model_name == "flux-schnell"
+ set_seed(args.seed)
+ # logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ device = "cuda"
+ dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
+
+ # # load image encoder
+ # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
+ #     # accelerator.device, dtype=torch.bfloat16
+ #     device, dtype=torch.bfloat16
+ # )
+ # ip_clip_image_processor = CLIPImageProcessor()
+
+ if args.use_ip:
+     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
+ elif args.use_spatial_condition:
+     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding)
+ else:
+     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
+
+
+ # @spaces.GPU
+ def generate(image, edit_prompt):
+     print("hello?????????!!!!!")
+
+     # accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+     # accelerator = Accelerator(
+     #     gradient_accumulation_steps=1,
+     #     mixed_precision=args.mixed_precision,
+     #     log_with=args.report_to,
+     #     project_config=accelerator_project_config,
+     # )
+
+     # Make one log on every process with the configuration for debugging.
+     # logging.basicConfig(
+     #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     #     datefmt="%m/%d/%Y %H:%M:%S",
+     #     level=logging.INFO,
+     # )
+     # logger.info(accelerator.state, main_process_only=False)
+     # if accelerator.is_local_main_process:
+     #     datasets.utils.logging.set_verbosity_warning()
+     #     transformers.utils.logging.set_verbosity_warning()
+     #     diffusers.utils.logging.set_verbosity_info()
+     # else:
+     #     datasets.utils.logging.set_verbosity_error()
+     #     transformers.utils.logging.set_verbosity_error()
+     #     diffusers.utils.logging.set_verbosity_error()
+
+
+     # if accelerator.is_main_process:
+     #     if args.output_dir is not None:
+     #         os.makedirs(args.output_dir, exist_ok=True)
+     #         gpt_eval_path = os.path.join(args.output_dir,"Eval")
+     #         os.makedirs(gpt_eval_path, exist_ok=True)
+
+     # dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell)
+     # dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
+
+     if args.use_lora:
+         lora_attn_procs = {}
+     if args.use_ip:
+         ip_attn_procs = {}
+     if args.double_blocks is None:
+         double_blocks_idx = list(range(19))
+     else:
+         double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")]
+
+     if args.single_blocks is None:
+         single_blocks_idx = list(range(38))
+     elif args.single_blocks is not None:
+         single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")]
+
+     if args.use_lora:
+         for name, attn_processor in dit.attn_processors.items():
+             match = re.search(r'\.(\d+)\.', name)
+             if match:
+                 layer_index = int(match.group(1))
+
+             if name.startswith("double_blocks") and layer_index in double_blocks_idx:
+                 # if accelerator.is_main_process:
+                 #     print("setting LoRA Processor for", name)
+                 lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
+                     dim=3072, rank=args.rank
+                 )
+             elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
+                 # if accelerator.is_main_process:
+                 #     print("setting LoRA Processor for", name)
+                 lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
+                     dim=3072, rank=args.rank
+                 )
+             else:
+                 lora_attn_procs[name] = attn_processor
+
+         dit.set_attn_processor(lora_attn_procs)
+
+     # if args.use_ip:
+     #     # unpack checkpoint
+     #     checkpoint = load_checkpoint(args.ip_local_path, args.ip_repo_id, args.ip_name)
+     #     prefix = "double_blocks."
+     #     # blocks = {}
+     #     proj = {}
+
+     #     for key, value in checkpoint.items():
+     #         # if key.startswith(prefix):
+     #         #     blocks[key[len(prefix):].replace('.processor.', '.')] = value
+     #         if key.startswith("ip_adapter_proj_model"):
+     #             proj[key[len("ip_adapter_proj_model."):]] = value
+
+     #     # # load image encoder
+     #     # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
+     #     #     # accelerator.device, dtype=torch.bfloat16
+     #     #     device, dtype=torch.bfloat16
+     #     # )
+     #     # ip_clip_image_processor = CLIPImageProcessor()
+
+     #     # setup image embedding projection model
+     #     ip_improj = ImageProjModel(4096, 768, 4)
+     #     ip_improj.load_state_dict(proj)
+     #     # ip_improj = ip_improj.to(accelerator.device, dtype=torch.bfloat16)
+     #     ip_improj = ip_improj.to(device, dtype=torch.bfloat16)
+
+     #     ip_attn_procs = {}
+
+     #     for name, _ in dit.attn_processors.items():
+     #         ip_state_dict = {}
+     #         for k in checkpoint.keys():
+     #             if name in k:
+     #                 ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
+     #         if ip_state_dict:
+     #             ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
+     #             ip_attn_procs[name].load_state_dict(ip_state_dict)
+     #             ip_attn_procs[name].to(accelerator.device, dtype=torch.bfloat16)
+     #         else:
+     #             ip_attn_procs[name] = dit.attn_processors[name]
+     #     dit.set_attn_processor(ip_attn_procs)
+
+
+     vae.requires_grad_(False)
+     t5.requires_grad_(False)
+     clip.requires_grad_(False)
+
+
+
+     # weight_dtype = torch.float32
+     # if accelerator.mixed_precision == "fp16":
+     #     weight_dtype = torch.float16
+     #     args.mixed_precision = accelerator.mixed_precision
+     # elif accelerator.mixed_precision == "bf16":
+     #     weight_dtype = torch.bfloat16
+     #     args.mixed_precision = accelerator.mixed_precision
+
+
+     # print(f"Resuming from checkpoint {args.ckpt_dir}")
+     # dit_stat_dict = load_file(args.ckpt_dir)
+     # Get path from Hub
+     model_path = hf_hub_download(
+         repo_id="Boese0601/ByteMorpher",
+         filename="dit.safetensors"
+     )
+     state_dict = load_file(model_path)
+     dit.load_state_dict(state_dict)
+     dit = dit.to(weight_dtype)
+     dit.eval()
+
+     # test_dataloader = loader(**args.data_config)
+     test_dataloader = eval_image_pair_loader(**args.data_config)
+
+
+
+     # from deepspeed import initialize
+     dit = accelerator.prepare(dit)
+
+     # if accelerator.is_main_process:
+     #     accelerator.init_trackers(args.tracker_project_name, {"test": None})
+
+     # logger.info("***** Running Evaluation *****")
+     # logger.info(f"  Instantaneous batch size = {args.eval_batch_size}")
+
+
+
+     # progress_bar = tqdm(
+     #     range(0, len(test_dataloader)),
+     #     initial=0,
+     #     desc="Steps",
+     #     disable=not accelerator.is_local_main_process,
+     # )
+
+     # for step, batch in enumerate(test_dataloader):
+     #     with accelerator.accumulate(dit):
+     #         img, tgt_image, prompt, edit_prompt, img_name, edit_name = batch
+     img = image_resize(image, 512)
+     w, h = img.size
+     new_w = (w // 32) * 32
+     new_h = (h // 32) * 32
+     img = img.resize((new_w, new_h))
+     img = torch.from_numpy((np.array(img) / 127.5) - 1)
+     img = img.permute(2, 0, 1).unsqueeze(0)
+
+     edit_prompt = edit_prompt
+
+     # if args.use_ip:
+     #     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
+     # elif args.use_spatial_condition:
+     #     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding)
+     # else:
+     #     sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
+     with torch.no_grad():
+         result = sampler(prompt=edit_prompt,
+                          width=args.sample_width,
+                          height=args.sample_height,
+                          num_steps=args.sample_steps,
+                          image_prompt=None,  # ip_adapter
+                          true_gs=args.cfg_scale,
+                          seed=args.seed,
+                          ip_scale=args.ip_scale if args.use_ip else 1.0,
+                          source_image=img if args.use_spatial_condition else None,
+                          )
+     gen_img = result
+
+
+
+     # progress_bar.update(1)
+
+     # accelerator.wait_for_everyone()
+     # accelerator.end_training()
+     return gen_img
+
+
+ def get_samples():
+     sample_list = [
+         {
+             "image": "assets/0_camera_zoom/20486354.png",
+             "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
+         },
+     ]
+     return [
+         [
+             Image.open(sample["image"]).resize((512, 512)),
+             sample["edit_prompt"],
+         ]
+         for sample in sample_list
+     ]
+
+
+ header = """
+ # ByteMoprh
+
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+ <a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
+ <a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
+ <a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
+ </div>
+ """
+
+
+ def create_app():
+     with gr.Blocks() as app:
+         gr.Markdown(header, elem_id="header")
+         with gr.Row(equal_height=False):
+             with gr.Column(variant="panel", elem_classes="inputPanel"):
+                 original_image = gr.Image(
+                     type="pil", label="Condition Image", width=300, elem_id="input"
+                 )
+                 edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
+                 submit_btn = gr.Button("Run", elem_id="submit_btn")
+
+             with gr.Column(variant="panel", elem_classes="outputPanel"):
+                 output_image = gr.Image(type="pil", elem_id="output")
+
+         with gr.Row():
+             examples = gr.Examples(
+                 examples=get_samples(),
+                 inputs=[original_image, edit_prompt],
+                 label="Examples",
+             )
+
+         submit_btn.click(
+             fn=generate,
+             inputs=[original_image, edit_prompt],
+             outputs=output_image,
+         )
+         gr.HTML(
+             """
+             <div style="text-align: center;">
+             * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
+             </div>
+             """
+         )
+     return app
+
+
+ if __name__ == "__main__":
+     print("CUDA available:", torch.cuda.is_available())
+     print("CUDA version:", torch.version.cuda)
+     print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
+     # mp.set_start_method("spawn", force=True)
+     create_app().launch(debug=False, share=True, ssr_mode=False)
image_datasets/.DS_Store ADDED
Binary file (6.15 kB).
 
image_datasets/dataset.py ADDED
@@ -0,0 +1,231 @@
+ import os
+ import pandas as pd
+ import numpy as np
+ from PIL import Image
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import json
+ import random
+ import glob
+ import torch
+ import torchvision.transforms.functional as TF
+
+
+
+ def image_resize(img, max_size=512):
+     w, h = img.size
+     if w >= h:
+         new_w = max_size
+         new_h = int((max_size / w) * h)
+     else:
+         new_h = max_size
+         new_w = int((max_size / h) * w)
+     return img.resize((new_w, new_h))
+
+ def c_crop(image):
+     width, height = image.size
+     new_size = min(width, height)
+     left = (width - new_size) / 2
+     top = (height - new_size) / 2
+     right = (width + new_size) / 2
+     bottom = (height + new_size) / 2
+     return image.crop((left, top, right, bottom))
+
+ def crop_to_aspect_ratio(image, ratio="16:9"):
+     width, height = image.size
+     ratio_map = {
+         "16:9": (16, 9),
+         "4:3": (4, 3),
+         "1:1": (1, 1)
+     }
+     target_w, target_h = ratio_map[ratio]
+     target_ratio_value = target_w / target_h
+
+     current_ratio = width / height
+
+     if current_ratio > target_ratio_value:
+         new_width = int(height * target_ratio_value)
+         offset = (width - new_width) // 2
+         crop_box = (offset, 0, offset + new_width, height)
+     else:
+         new_height = int(width / target_ratio_value)
+         offset = (height - new_height) // 2
+         crop_box = (0, offset, width, offset + new_height)
+
+     cropped_img = image.crop(crop_box)
+     return cropped_img
+
+
+ class CustomImageDataset(Dataset):
+     def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False):
+         self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
+         # self.images = glob.glob(img_dir +'**/*.jpg', recursive=True) + glob.glob(img_dir +'**/*.png', recursive=True) + glob.glob(img_dir +'**/*.jpeg', recursive=True)
+         self.images.sort()
+         self.img_size = img_size
+         self.caption_type = caption_type
+         self.random_ratio = random_ratio
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, idx):
+         try:
+             img = Image.open(self.images[idx]).convert('RGB')
+
+             if self.random_ratio:
+                 ratio = random.choice(["16:9", "default", "1:1", "4:3"])
+                 if ratio != "default":
+                     img = crop_to_aspect_ratio(img, ratio)
+             img = image_resize(img, self.img_size)
+             w, h = img.size
+             new_w = (w // 32) * 32
+             new_h = (h // 32) * 32
+             img = img.resize((new_w, new_h))
+             img = torch.from_numpy((np.array(img) / 127.5) - 1)
+             img = img.permute(2, 0, 1)
+             json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
+             if self.caption_type == "json":
+                 prompt = json.load(open(json_path))['caption']
+             else:
+                 prompt = open(json_path).read()
+             return img, prompt
+         except Exception as e:
+             print(e)
+             return self.__getitem__(random.randint(0, len(self.images) - 1))
+
+
+ def loader(train_batch_size, num_workers, **args):
+     dataset = CustomImageDataset(**args)
+     return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
+
+
+
+ class ImageEditPairDataset(Dataset):
+     def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False, grayscale_editing=False, zoom_camera=False):
+         # self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
+         self.images = glob.glob(img_dir +'**/*.jpg', recursive=True) + glob.glob(img_dir +'**/*.png', recursive=True) + glob.glob(img_dir +'**/*.jpeg', recursive=True)
+         self.images.sort()
+         self.img_size = img_size
+         self.caption_type = caption_type
+         self.random_ratio = random_ratio
+         self.grayscale_editing = grayscale_editing
+         self.zoom_camera = zoom_camera
+         # note: each substring must be tested against img_dir explicitly
+         if "ByteMorph-Bench" in img_dir or "InstructMove" in img_dir:
+             self.eval = True
+         else:
+             self.eval = False
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, idx):
+         try:
+             img = Image.open(self.images[idx]).convert('RGB')
+             ori_width, ori_height = img.size
+             left_half = (0, 0, ori_width // 2, ori_height)
+             right_half = (ori_width // 2, 0, ori_width, ori_height)
+             src_image = img.crop(left_half)  # Left half
+             tgt_image = img.crop(right_half)  # Right half
+             # print("ori_width, ori_height: ",ori_width, ori_height)
+             if self.random_ratio:
+                 ratio = random.choice(["16:9", "default", "1:1", "4:3"])
+                 if ratio != "default":
+                     src_image = crop_to_aspect_ratio(src_image, ratio)
+                     tgt_image = crop_to_aspect_ratio(tgt_image, ratio)
+             src_image = image_resize(src_image, self.img_size)
+             tgt_image = image_resize(tgt_image, self.img_size)
+             w, h = src_image.size
+             new_w = (w // 32) * 32
+             new_h = (h // 32) * 32
+             # print("new_w, new_h: ",new_w, new_h)
+             src_image = src_image.resize((new_w, new_h))
+             src_image = torch.from_numpy((np.array(src_image) / 127.5) - 1)
+             src_image = src_image.permute(2, 0, 1)
+             tgt_image = tgt_image.resize((new_w, new_h))
+             tgt_image = torch.from_numpy((np.array(tgt_image) / 127.5) - 1)
+             tgt_image = tgt_image.permute(2, 0, 1)
+             json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
+             if self.eval:
+                 image_name = self.images[idx].split('.')[0].split("/")[-1]
+                 edit_type = self.images[idx].split('.')[0].split("/")[-2]
+             if self.caption_type == "json":
+                 if not self.eval:
+                     prompt = json.load(open(json_path))['caption']
+                     edit_prompt = json.load(open(json_path))['edit']
+                 else:
+                     prompt = []  # json.load(open(json_path))['caption']
+                     edit_prompt = json.load(open(json_path))['edit']
+             else:
+                 raise NotImplementedError
+                 # prompt = open(json_path).read()
+             if (not self.grayscale_editing) and (not self.zoom_camera):
+                 if not self.eval:
+                     return src_image, tgt_image, prompt, edit_prompt
+                 else:
+                     return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
+             if self.grayscale_editing and (not self.zoom_camera):
+                 # Grayscale = 0.2989 * R + 0.5870 * G + 0.1140 * B
+                 grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :]
+                 tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1)
+                 edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details."
+                 if not self.eval:
+                     return src_image, tgt_image, prompt, edit_prompt
+                 else:
+                     return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
+             if (not self.grayscale_editing) and self.zoom_camera:
+                 cropped = TF.center_crop(src_image, (256, 256))
+                 tgt_image = TF.resize(cropped, (512, 512))
+                 edit_prompt = "The central area of the input image is zoomed. The camera transitions from a wide shot to a closer position, narrowing its view."
+                 if not self.eval:
+                     return src_image, tgt_image, prompt, edit_prompt
+                 else:
+                     return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
+             if self.grayscale_editing and self.zoom_camera:
+                 grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :]
+                 tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1)
+                 tgt_image = TF.center_crop(tgt_image, (256, 256))
+                 tgt_image = TF.resize(tgt_image, (512, 512))
+                 edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details. And the central area of the input image is zoomed, the camera transitions from a wide shot to a closer position, narrowing its view."
+                 if not self.eval:
+                     return src_image, tgt_image, prompt, edit_prompt
+                 else:
+                     return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
+         except Exception as e:
+             print(e)
+             return self.__getitem__(random.randint(0, len(self.images) - 1))
+
+
+ def image_pair_loader(train_batch_size, num_workers, **args):
+     dataset = ImageEditPairDataset(**args)
+     return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
+
+ def eval_image_pair_loader(eval_batch_size, num_workers, **args):
+     dataset = ImageEditPairDataset(**args)
+     return DataLoader(dataset, batch_size=eval_batch_size, num_workers=num_workers, shuffle=False)
+
+
+
+ if __name__ == "__main__":
+     from src.flux.util import save_image
+     example_dataset = ImageEditPairDataset(
+         img_dir="",
+         img_size=512,
+         caption_type='json',
+         random_ratio=False,
+         grayscale_editing=False,
+         zoom_camera=False,
+     )
+
+     train_dataloader = DataLoader(
+         example_dataset,
+         batch_size=1,
+         num_workers=4,
+         shuffle=False,
+     )
+
+     for step, batch in enumerate(train_dataloader):
+         src_image, tgt_image, prompt, edit_prompt = batch
+         os.makedirs("./debug", exist_ok=True)
+         save_image(src_image, f"./debug/{step}-src_img.jpg")
+         save_image(tgt_image, f"./debug/{step}-tgt_img.jpg")
+         if step == 3:
+             breakpoint()
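
Each sample consumed by ImageEditPairDataset is stored as a single image with source and target concatenated horizontally; `__getitem__` crops the left half as the condition image and the right half as the edit target. A small sketch of that split, using a blank stand-in for a real pair file:

```python
from PIL import Image

# Stand-in for a real pair file: [source | target] stored side by side.
pair = Image.new("RGB", (1024, 512))
w, h = pair.size
src_image = pair.crop((0, 0, w // 2, h))   # left half: condition image
tgt_image = pair.crop((w // 2, 0, w, h))   # right half: edited target
assert src_image.size == tgt_image.size == (512, 512)
```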
inference_configs/inference.yaml ADDED
@@ -0,0 +1,33 @@
+ model_name: "flux-dev"
+ use_spatial_condition: true
+ share_position_embedding: true
+ use_share_weight_referencenet: false
+ use_ip: false
+ ip_local_path: null
+ ip_repo_id: null
+ ip_name: null
+ ip_scale: 1.0
+ use_lora: false
+ data_config:
+   eval_batch_size: 1
+   num_workers: 0
+   img_size: 512
+   img_dir: output_bench/ #./ByteMorph-Bench/
+   grayscale_editing: false
+   zoom_camera: false
+   random_ratio: false
+ report_to: wandb
+ eval_batch_size: 1
+ ckpt_dir: ./pretrained_weights/ByteMorpher/dit.safetensors
+ output_dir: ./test_log/seedmorpher/
+ logging_dir: logs
+ mixed_precision: "bf16"
+ rank: 16
+ single_blocks: null
+ double_blocks: null
+ disable_sampling: false
+ sample_width: 512
+ sample_height: 512
+ sample_steps: 25
+ seed: 42
+ cfg_scale: 3.5
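
app.py loads this file with OmegaConf, so keys are attribute-accessible and the nested data_config block can be unpacked straight into the eval loader. A quick sketch (the output comments reflect the values above):

```python
from omegaconf import OmegaConf

args = OmegaConf.load("inference_configs/inference.yaml")
print(args.model_name)                    # flux-dev
print(args.sample_steps, args.cfg_scale)  # 25 3.5
# Nested blocks behave the same way; eval_image_pair_loader(**args.data_config)
# receives eval_batch_size, num_workers, img_size, img_dir, and so on.
print(args.data_config.img_dir)           # output_bench/
```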
requirements.txt CHANGED
@@ -1,6 +1,26 @@
- accelerate
+ # --extra-index-url https://download.pytorch.org/whl/cu124
+ # torch==2.6.0
+ # torchvision==0.21.0
+ # torchaudio==2.6.0
+
+ gradio>=4.0
+ accelerate==0.30.1
+ deepspeed==0.14.4
+ einops==0.8.0
+ transformers==4.43.3
+ huggingface-hub==0.24.5
+ optimum-quanto
+ datasets
+ omegaconf
  diffusers
- invisible_watermark
- torch
- transformers
- xformers
+ sentencepiece
+ opencv-python
+ matplotlib
+ onnxruntime
+ timm
+ wandb
+
+ setuptools
+ wheel
+
+
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/flux/.DS_Store ADDED
Binary file (6.15 kB).
 
src/flux/__init__.py ADDED
@@ -0,0 +1,11 @@
+ try:
+     from ._version import version as __version__  # type: ignore
+     from ._version import version_tuple
+ except ImportError:
+     __version__ = "unknown (no version information available)"
+     version_tuple = (0, 0, "unknown", "noinfo")
+
+ from pathlib import Path
+
+ PACKAGE = __package__.replace("_", "-")
+ PACKAGE_ROOT = Path(__file__).parent
src/flux/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .cli import app
+
+ if __name__ == "__main__":
+     app()
src/flux/annotator/canny/__init__.py ADDED
@@ -0,0 +1,6 @@
+ import cv2
+
+
+ class CannyDetector:
+     def __call__(self, img, low_threshold, high_threshold):
+         return cv2.Canny(img, low_threshold, high_threshold)
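
The detector is a thin pass-through to OpenCV's Canny; a short usage sketch with illustrative thresholds and a random stand-in image:

```python
import cv2
import numpy as np
from src.flux.annotator.canny import CannyDetector  # path added by this commit

detector = CannyDetector()
img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # stand-in image
edges = detector(img, low_threshold=100, high_threshold=200)
# cv2.Canny returns a single-channel uint8 edge map at the input resolution.
assert edges.shape == (256, 256) and edges.dtype == np.uint8
```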
src/flux/annotator/ckpts/ckpts.txt ADDED
@@ -0,0 +1 @@
+ Weights here.
src/flux/annotator/dwpose/__init__.py ADDED
@@ -0,0 +1,68 @@
+ # Openpose
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+ # 3rd Edited by ControlNet
+ # 4th Edited by ControlNet (added face and correct hands)
+
+ import os
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+ import torch
+ import numpy as np
+ from . import util
+ from .wholebody import Wholebody
+
+ def draw_pose(pose, H, W):
+     bodies = pose['bodies']
+     faces = pose['faces']
+     hands = pose['hands']
+     candidate = bodies['candidate']
+     subset = bodies['subset']
+     canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+
+     canvas = util.draw_bodypose(canvas, candidate, subset)
+
+     canvas = util.draw_handpose(canvas, hands)
+
+     canvas = util.draw_facepose(canvas, faces)
+
+     return canvas
+
+
+ class DWposeDetector:
+     def __init__(self, device):
+
+         self.pose_estimation = Wholebody(device)
+
+     def __call__(self, oriImg):
+         oriImg = oriImg.copy()
+         H, W, C = oriImg.shape
+         with torch.no_grad():
+             candidate, subset = self.pose_estimation(oriImg)
+             nums, keys, locs = candidate.shape
+             candidate[..., 0] /= float(W)
+             candidate[..., 1] /= float(H)
+             body = candidate[:, :18].copy()
+             body = body.reshape(nums * 18, locs)
+             score = subset[:, :18]
+             for i in range(len(score)):
+                 for j in range(len(score[i])):
+                     if score[i][j] > 0.3:
+                         score[i][j] = int(18 * i + j)
+                     else:
+                         score[i][j] = -1
+
+             un_visible = subset < 0.3
+             candidate[un_visible] = -1
+
+             foot = candidate[:, 18:24]
+
+             faces = candidate[:, 24:92]
+
+             hands = candidate[:, 92:113]
+             hands = np.vstack([hands, candidate[:, 113:]])
+
+             bodies = dict(candidate=body, subset=score)
+             pose = dict(bodies=bodies, hands=hands, faces=faces)
+
+             return draw_pose(pose, H, W)
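
A hedged usage sketch: the detector takes an HxWx3 uint8 frame and returns a rendered skeleton canvas of the same spatial size. This assumes the ONNX checkpoints that Wholebody loads are already in place (they are not part of this commit), and passing "cpu" as the device is an assumption about Wholebody's constructor:

```python
import numpy as np
from src.flux.annotator.dwpose import DWposeDetector  # path added by this commit

detector = DWposeDetector(device="cpu")  # device goes straight to Wholebody (assumption)
frame = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real photo
canvas = detector(frame)                 # H x W x 3 pose rendering via draw_pose
assert canvas.shape == frame.shape
```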
src/flux/annotator/dwpose/onnxdet.py ADDED
@@ -0,0 +1,125 @@
+ import cv2
+ import numpy as np
+
+ import onnxruntime
+
+ def nms(boxes, scores, nms_thr):
+     """Single class NMS implemented in Numpy."""
+     x1 = boxes[:, 0]
+     y1 = boxes[:, 1]
+     x2 = boxes[:, 2]
+     y2 = boxes[:, 3]
+
+     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+     order = scores.argsort()[::-1]
+
+     keep = []
+     while order.size > 0:
+         i = order[0]
+         keep.append(i)
+         xx1 = np.maximum(x1[i], x1[order[1:]])
+         yy1 = np.maximum(y1[i], y1[order[1:]])
+         xx2 = np.minimum(x2[i], x2[order[1:]])
+         yy2 = np.minimum(y2[i], y2[order[1:]])
+
+         w = np.maximum(0.0, xx2 - xx1 + 1)
+         h = np.maximum(0.0, yy2 - yy1 + 1)
+         inter = w * h
+         ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+         inds = np.where(ovr <= nms_thr)[0]
+         order = order[inds + 1]
+
+     return keep
+
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
+     """Multiclass NMS implemented in Numpy. Class-aware version."""
+     final_dets = []
+     num_classes = scores.shape[1]
+     for cls_ind in range(num_classes):
+         cls_scores = scores[:, cls_ind]
+         valid_score_mask = cls_scores > score_thr
+         if valid_score_mask.sum() == 0:
+             continue
+         else:
+             valid_scores = cls_scores[valid_score_mask]
+             valid_boxes = boxes[valid_score_mask]
+             keep = nms(valid_boxes, valid_scores, nms_thr)
+             if len(keep) > 0:
+                 cls_inds = np.ones((len(keep), 1)) * cls_ind
+                 dets = np.concatenate(
+                     [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+                 )
+                 final_dets.append(dets)
+     if len(final_dets) == 0:
+         return None
+     return np.concatenate(final_dets, 0)
+
+ def demo_postprocess(outputs, img_size, p6=False):
+     grids = []
+     expanded_strides = []
+     strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
+
+     hsizes = [img_size[0] // stride for stride in strides]
+     wsizes = [img_size[1] // stride for stride in strides]
+
+     for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+         xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+         grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+         grids.append(grid)
+         shape = grid.shape[:2]
+         expanded_strides.append(np.full((*shape, 1), stride))
+
+     grids = np.concatenate(grids, 1)
+     expanded_strides = np.concatenate(expanded_strides, 1)
+     outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+     outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+     return outputs
+
+ def preprocess(img, input_size, swap=(2, 0, 1)):
+     if len(img.shape) == 3:
+         padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+     else:
+         padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+     r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+     resized_img = cv2.resize(
+         img,
+         (int(img.shape[1] * r), int(img.shape[0] * r)),
+         interpolation=cv2.INTER_LINEAR,
+     ).astype(np.uint8)
+     padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+     padded_img = padded_img.transpose(swap)
+     padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+     return padded_img, r
+
+ def inference_detector(session, oriImg):
+     input_shape = (640, 640)
+     img, ratio = preprocess(oriImg, input_shape)
+
+     ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
+     output = session.run(None, ort_inputs)
+     predictions = demo_postprocess(output[0], input_shape)[0]
+
+     boxes = predictions[:, :4]
+     scores = predictions[:, 4:5] * predictions[:, 5:]
+
+     boxes_xyxy = np.ones_like(boxes)
+     boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
+     boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
+     boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
+     boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
+     boxes_xyxy /= ratio
+     dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+     if dets is not None:
+         final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
+         isscore = final_scores > 0.3
+         iscat = final_cls_inds == 0
+         isbbox = [i and j for (i, j) in zip(isscore, iscat)]
+         final_boxes = final_boxes[isbbox]
+     else:
+         final_boxes = np.array([])
+
+     return final_boxes
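
The nms helper is the classic greedy single-class routine (note the +1 in the area and overlap math, which treats boxes as inclusive pixel grids). A tiny worked example:

```python
import numpy as np
from src.flux.annotator.dwpose.onnxdet import nms  # path added by this commit

boxes = np.array([[0, 0, 10, 10],
                  [1, 1, 11, 11],      # IoU ~ 0.70 with the first box
                  [50, 50, 60, 60]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
# The lower-scoring overlapping box is suppressed at threshold 0.45.
print(nms(boxes, scores, nms_thr=0.45))  # [0, 2]
```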
src/flux/annotator/dwpose/onnxpose.py ADDED
@@ -0,0 +1,360 @@
+ from typing import List, Tuple
+
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+
+ def preprocess(
+     img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+     """Do preprocessing for RTMPose model inference.
+
+     Args:
+         img (np.ndarray): Input image in shape.
+         input_size (tuple): Input image size in shape (w, h).
+
+     Returns:
+         tuple:
+         - resized_img (np.ndarray): Preprocessed image.
+         - center (np.ndarray): Center of image.
+         - scale (np.ndarray): Scale of image.
+     """
+     # get shape of image
+     img_shape = img.shape[:2]
+     out_img, out_center, out_scale = [], [], []
+     if len(out_bbox) == 0:
+         out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
+     for i in range(len(out_bbox)):
+         x0 = out_bbox[i][0]
+         y0 = out_bbox[i][1]
+         x1 = out_bbox[i][2]
+         y1 = out_bbox[i][3]
+         bbox = np.array([x0, y0, x1, y1])
+
+         # get center and scale
+         center, scale = bbox_xyxy2cs(bbox, padding=1.25)
+
+         # do affine transformation
+         resized_img, scale = top_down_affine(input_size, scale, center, img)
+
+         # normalize image
+         mean = np.array([123.675, 116.28, 103.53])
+         std = np.array([58.395, 57.12, 57.375])
+         resized_img = (resized_img - mean) / std
+
+         out_img.append(resized_img)
+         out_center.append(center)
+         out_scale.append(scale)
+
+     return out_img, out_center, out_scale
+
+
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
+     """Inference RTMPose model.
+
+     Args:
+         sess (ort.InferenceSession): ONNXRuntime session.
+         img (np.ndarray): Input image in shape.
+
+     Returns:
+         outputs (np.ndarray): Output of RTMPose model.
+     """
+     all_out = []
+     # build input
+     for i in range(len(img)):
+         input = [img[i].transpose(2, 0, 1)]
+
+         # build output
+         sess_input = {sess.get_inputs()[0].name: input}
+         sess_output = []
+         for out in sess.get_outputs():
+             sess_output.append(out.name)
+
+         # run model
+         outputs = sess.run(sess_output, sess_input)
+         all_out.append(outputs)
+
+     return all_out
+
+
+ def postprocess(outputs: List[np.ndarray],
+                 model_input_size: Tuple[int, int],
+                 center: Tuple[int, int],
+                 scale: Tuple[int, int],
+                 simcc_split_ratio: float = 2.0
+                 ) -> Tuple[np.ndarray, np.ndarray]:
+     """Postprocess for RTMPose model output.
+
+     Args:
+         outputs (np.ndarray): Output of RTMPose model.
+         model_input_size (tuple): RTMPose model Input image size.
+         center (tuple): Center of bbox in shape (x, y).
+         scale (tuple): Scale of bbox in shape (w, h).
+         simcc_split_ratio (float): Split ratio of simcc.
+
+     Returns:
+         tuple:
+         - keypoints (np.ndarray): Rescaled keypoints.
+         - scores (np.ndarray): Model predict scores.
+     """
+     all_key = []
+     all_score = []
+     for i in range(len(outputs)):
+         # use simcc to decode
+         simcc_x, simcc_y = outputs[i]
+         keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
+
+         # rescale keypoints
+         keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
+         all_key.append(keypoints[0])
+         all_score.append(scores[0])
+
+     return np.array(all_key), np.array(all_score)
+
+
+ def bbox_xyxy2cs(bbox: np.ndarray,
+                  padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+     """Transform the bbox format from (x,y,w,h) into (center, scale)
+
+     Args:
+         bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+             as (left, top, right, bottom)
+         padding (float): BBox padding factor that will be multiplied to scale.
+             Default: 1.0
+
+     Returns:
+         tuple: A tuple containing center and scale.
+         - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
+             (n, 2)
+         - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
+             (n, 2)
+     """
+     # convert single bbox from (4, ) to (1, 4)
+     dim = bbox.ndim
+     if dim == 1:
+         bbox = bbox[None, :]
+
+     # get bbox center and scale
+     x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
+     center = np.hstack([x1 + x2, y1 + y2]) * 0.5
+     scale = np.hstack([x2 - x1, y2 - y1]) * padding
+
+     if dim == 1:
+         center = center[0]
+         scale = scale[0]
+
+     return center, scale
+
+
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
+                       aspect_ratio: float) -> np.ndarray:
+     """Extend the scale to match the given aspect ratio.
+
+     Args:
+         scale (np.ndarray): The image scale (w, h) in shape (2, )
+         aspect_ratio (float): The ratio of ``w/h``
+
+     Returns:
+         np.ndarray: The reshaped image scale in (2, )
+     """
+     w, h = np.hsplit(bbox_scale, [1])
+     bbox_scale = np.where(w > h * aspect_ratio,
+                           np.hstack([w, w / aspect_ratio]),
+                           np.hstack([h * aspect_ratio, h]))
+     return bbox_scale
+
+
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
+     """Rotate a point by an angle.
+
+     Args:
+         pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
+         angle_rad (float): rotation angle in radian
+
+     Returns:
+         np.ndarray: Rotated point in shape (2, )
+     """
+     sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+     rot_mat = np.array([[cs, -sn], [sn, cs]])
+     return rot_mat @ pt
+
+
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+     """To calculate the affine matrix, three pairs of points are required. This
+     function is used to get the 3rd point, given 2D points a & b.
+
+     The 3rd point is defined by rotating vector `a - b` by 90 degrees
+     anticlockwise, using b as the rotation center.
+
+     Args:
+         a (np.ndarray): The 1st point (x,y) in shape (2, )
+         b (np.ndarray): The 2nd point (x,y) in shape (2, )
+
+     Returns:
+         np.ndarray: The 3rd point.
+     """
+     direction = a - b
+     c = b + np.r_[-direction[1], direction[0]]
+     return c
+
+
+ def get_warp_matrix(center: np.ndarray,
+                     scale: np.ndarray,
+                     rot: float,
+                     output_size: Tuple[int, int],
+                     shift: Tuple[float, float] = (0., 0.),
+                     inv: bool = False) -> np.ndarray:
+     """Calculate the affine transformation matrix that can warp the bbox area
+     in the input image to the output size.
+
+     Args:
+         center (np.ndarray[2, ]): Center of the bounding box (x, y).
+         scale (np.ndarray[2, ]): Scale of the bounding box
+             wrt [width, height].
+         rot (float): Rotation angle (degree).
+         output_size (np.ndarray[2, ] | list(2,)): Size of the
+             destination heatmaps.
+         shift (0-100%): Shift translation ratio wrt the width/height.
+             Default (0., 0.).
+         inv (bool): Option to inverse the affine transform direction.
+             (inv=False: src->dst or inv=True: dst->src)
+
+     Returns:
+         np.ndarray: A 2x3 transformation matrix
+     """
+     shift = np.array(shift)
+     src_w = scale[0]
+     dst_w = output_size[0]
+     dst_h = output_size[1]
+
+     # compute transformation matrix
+     rot_rad = np.deg2rad(rot)
+     src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
+     dst_dir = np.array([0., dst_w * -0.5])
+
+     # get four corners of the src rectangle in the original image
+     src = np.zeros((3, 2), dtype=np.float32)
+     src[0, :] = center + scale * shift
+     src[1, :] = center + src_dir + scale * shift
+     src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+     # get four corners of the dst rectangle in the input image
+     dst = np.zeros((3, 2), dtype=np.float32)
+     dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+     dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+     dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+     if inv:
+         warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+     else:
+         warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+     return warp_mat
+
+
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
+                     img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+     """Get the bbox image as the model input by affine transform.
+
+     Args:
+         input_size (dict): The input size of the model.
+         bbox_scale (dict): The bbox scale of the img.
+         bbox_center (dict): The bbox center of the img.
+         img (np.ndarray): The original image.
+
+     Returns:
+         tuple: A tuple containing center and scale.
+         - np.ndarray[float32]: img after affine transform.
+         - np.ndarray[float32]: bbox scale after affine transform.
+     """
+     w, h = input_size
+     warp_size = (int(w), int(h))
+
+     # reshape bbox to fixed aspect ratio
+     bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
+
+     # get the affine matrix
+     center = bbox_center
+     scale = bbox_scale
+     rot = 0
+     warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
+
+     # do affine transform
+     img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
+
+     return img, bbox_scale
+
+
+ def get_simcc_maximum(simcc_x: np.ndarray,
+                       simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+     """Get maximum response location and value from simcc representations.
+
+     Note:
+         instance number: N
+         num_keypoints: K
+         heatmap height: H
+         heatmap width: W
+
+     Args:
+         simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+         simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+
+     Returns:
+         tuple:
+         - locs (np.ndarray): locations of maximum heatmap responses in shape
+             (K, 2) or (N, K, 2)
+         - vals (np.ndarray): values of maximum heatmap responses in shape
+             (K,) or (N, K)
+     """
+     N, K, Wx = simcc_x.shape
+     simcc_x = simcc_x.reshape(N * K, -1)
+     simcc_y = simcc_y.reshape(N * K, -1)
+
+     # get maximum value locations
+     x_locs = np.argmax(simcc_x, axis=1)
+     y_locs = np.argmax(simcc_y, axis=1)
+     locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+     max_val_x = np.amax(simcc_x, axis=1)
+     max_val_y = np.amax(simcc_y, axis=1)
+
+     # get maximum value across x and y axis
+     mask = max_val_x > max_val_y
+     max_val_x[mask] = max_val_y[mask]
+     vals = max_val_x
+     locs[vals <= 0.] = -1
+
+     # reshape
+     locs = locs.reshape(N, K, 2)
+     vals = vals.reshape(N, K)
+
+     return locs, vals
+
+
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
+            simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
+     """Modulate simcc distribution with Gaussian.
+
+     Args:
+         simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340
+ simcc_split_ratio (int): The split ratio of simcc.
341
+
342
+ Returns:
343
+ tuple: A tuple containing center and scale.
344
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
346
+ """
347
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348
+ keypoints /= simcc_split_ratio
349
+
350
+ return keypoints, scores
351
+
352
+
353
+ def inference_pose(session, out_bbox, oriImg):
354
+ h, w = session.get_inputs()[0].shape[2:]
355
+ model_input_size = (w, h)
356
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
357
+ outputs = inference(session, resized_img)
358
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
359
+
360
+ return keypoints, scores
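Taken together, these helpers form the usual top-down SimCC pipeline: warp each detected box to the model's input size, run the ONNX session, then decode the per-axis bins back to image coordinates. A minimal sketch of the warp round-trip on synthetic inputs (nothing below is part of the module itself; all values are made up):

import numpy as np
import cv2

# toy bbox: center (x, y) and padded (w, h) inside a 640x480 image
img = np.zeros((480, 640, 3), dtype=np.uint8)
center = np.array([200., 150.], dtype=np.float32)
scale = np.array([96., 128.], dtype=np.float32)

# warp the box region to a 192x256 crop for the pose model
warp = get_warp_matrix(center, scale, rot=0., output_size=(192, 256))
crop = cv2.warpAffine(img, warp, (192, 256), flags=cv2.INTER_LINEAR)

# map a model-space keypoint back to image space with the inverse matrix
kpt_crop = np.array([96., 128.])
inv = get_warp_matrix(center, scale, rot=0., output_size=(192, 256), inv=True)
kpt_img = inv @ np.r_[kpt_crop, 1.0]   # 2x3 matrix times homogeneous point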
src/flux/annotator/dwpose/util.py ADDED
@@ -0,0 +1,297 @@
import math
import numpy as np
import matplotlib
import cv2


eps = 0.01


def smart_resize(x, s):
    Ht, Wt = s
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)


def smart_resize_k(x, fx, fy):
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    Ht, Wt = Ho * fy, Wo * fx
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)


def padRightDownCorner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad


def transfer(model, model_weights):
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
    return transfered_model_weights


def draw_bodypose(canvas, candidate, subset):
    H, W, C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    stickwidth = 4

    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, colors[i])

    canvas = (canvas * 0.6).astype(np.uint8)

    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)

    return canvas


def draw_handpose(canvas, all_hand_peaks):
    H, W, C = canvas.shape

    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]

    for peaks in all_hand_peaks:
        peaks = np.array(peaks)

        for ie, e in enumerate(edges):
            x1, y1 = peaks[e[0]]
            x2, y2 = peaks[e[1]]
            x1 = int(x1 * W)
            y1 = int(y1 * H)
            x2 = int(x2 * W)
            y2 = int(y2 * H)
            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)

        for i, keypoint in enumerate(peaks):
            x, y = keypoint
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
    return canvas


def draw_facepose(canvas, all_lmks):
    H, W, C = canvas.shape
    for lmks in all_lmks:
        lmks = np.array(lmks)
        for lmk in lmks:
            x, y = lmk
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
    return canvas


# detect hands from body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
    # right hand: wrist 4, elbow 3, shoulder 2
    # left hand: wrist 7, elbow 6, shoulder 5
    ratioWristElbow = 0.33
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        # an arm counts only if wrist, elbow and shoulder are all detected
        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
        if not (has_left or has_right):
            continue
        hands = []
        # left hand
        if has_left:
            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
            x1, y1 = candidate[left_shoulder_index][:2]
            x2, y2 = candidate[left_elbow_index][:2]
            x3, y3 = candidate[left_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, True])
        # right hand
        if has_right:
            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
            x1, y1 = candidate[right_shoulder_index][:2]
            x2, y2 = candidate[right_elbow_index][:2]
            x3, y3 = candidate[right_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, False])

        for x1, y1, x2, y2, x3, y3, is_left in hands:
            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
            # x-y refers to the center --> offset to the top-left point
            # handRectangle.x -= handRectangle.width / 2.f;
            # handRectangle.y -= handRectangle.height / 2.f;
            x -= width / 2
            y -= width / 2  # width = height
            # clip the box to the image bounds
            if x < 0: x = 0
            if y < 0: y = 0
            width1 = width
            width2 = width
            if x + width > image_width: width1 = image_width - x
            if y + width > image_height: width2 = image_height - y
            width = min(width1, width2)
            # discard hand boxes smaller than 20 pixels
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    '''
    return value: [[x, y, w, True if left hand else False]].
    width = height since the network requires square input.
    x, y is the coordinate of the top-left corner.
    '''
    return detect_result


# Written by Lvmin
def faceDetect(candidate, subset, oriImg):
    # nose 0; eyes 14, 15; ears 16, 17
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        has_head = person[0] > -1
        if not has_head:
            continue

        has_left_eye = person[14] > -1
        has_right_eye = person[15] > -1
        has_left_ear = person[16] > -1
        has_right_ear = person[17] > -1

        if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
            continue

        head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]

        width = 0.0
        x0, y0 = candidate[head][:2]

        if has_left_eye:
            x1, y1 = candidate[left_eye][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 3.0)

        if has_right_eye:
            x1, y1 = candidate[right_eye][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 3.0)

        if has_left_ear:
            x1, y1 = candidate[left_ear][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 1.5)

        if has_right_ear:
            x1, y1 = candidate[right_ear][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 1.5)

        x, y = x0, y0

        x -= width
        y -= width

        if x < 0:
            x = 0

        if y < 0:
            y = 0

        width1 = width * 2
        width2 = width * 2

        if x + width > image_width:
            width1 = image_width - x

        if y + width > image_height:
            width2 = image_height - y

        width = min(width1, width2)

        # discard face boxes smaller than 20 pixels
        if width >= 20:
            detect_result.append([int(x), int(y), int(width)])

    return detect_result


# get the (row, col) index of the global maximum of a 2d array
def npmax(array):
    arrayindex = array.argmax(1)
    arrayvalue = array.max(1)
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j
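A quick sketch of how these drawing and detection utilities are typically driven; the pose below is random, so only the shapes and the coordinate conventions matter (drawing takes normalized (x, y), the box detectors expect pixel coordinates):

import numpy as np

canvas = np.zeros((512, 512, 3), dtype=np.uint8)

# candidate: 18 body keypoints as normalized (x, y); subset: one row per
# person, mapping each joint slot to a row of candidate (-1 = missing)
candidate = np.random.rand(18, 2)
subset = np.arange(18, dtype=float)[None, :]   # one person, all joints found

canvas = draw_bodypose(canvas, candidate, subset)

# scale the same pose to pixels for the box detectors
hand_boxes = handDetect(candidate * 512, subset, canvas)  # [[x, y, w, is_left], ...]
face_boxes = faceDetect(candidate * 512, subset, canvas)  # [[x, y, w], ...]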
src/flux/annotator/dwpose/wholebody.py ADDED
@@ -0,0 +1,48 @@
import cv2
import numpy as np

import onnxruntime as ort
from huggingface_hub import hf_hub_download
from .onnxdet import inference_detector
from .onnxpose import inference_pose


class Wholebody:
    def __init__(self, device="cuda:0"):
        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
        onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx")
        onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx")

        self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
        self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

    def __call__(self, oriImg):
        det_result = inference_detector(self.session_det, oriImg)
        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)

        keypoints_info = np.concatenate(
            (keypoints, scores[..., None]), axis=-1)
        # compute the neck joint as the midpoint of the two shoulders
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # the neck is scored as visible only if both shoulders are
        neck[:, 2:4] = np.logical_and(
            keypoints_info[:, 5, 2:4] > 0.3,
            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
        new_keypoints_info = np.insert(
            keypoints_info, 17, neck, axis=1)
        # remap from the MMPose joint order to the OpenPose joint order
        mmpose_idx = [
            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
        ]
        openpose_idx = [
            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
        ]
        new_keypoints_info[:, openpose_idx] = \
            new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores = keypoints_info[
            ..., :2], keypoints_info[..., 2]

        return keypoints, scores
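End-to-end usage is then a single call; a sketch, assuming a CUDA-enabled onnxruntime build, network access for the first hf_hub_download, and a hypothetical image path (134 joints = the 133 whole-body keypoints plus the inserted neck):

import cv2

model = Wholebody(device="cuda:0")
img = cv2.imread("person.jpg")                  # hypothetical input file
keypoints, scores = model(img)                  # (num_people, 134, 2), (num_people, 134)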
src/flux/annotator/hed/__init__.py ADDED
@@ -0,0 +1,95 @@
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products.
# This implementation may produce slightly different results from Saining Xie's official version,
# but it generates smoother edges and is more suitable for ControlNet and other image-to-image translations.
# Unlike the official models and other implementations, this is an RGB-input model (rather than BGR),
# which makes it work better with Gradio's RGB convention.

import os
import cv2
import torch
import numpy as np

from huggingface_hub import hf_hub_download
from einops import rearrange
from ...annotator.util import annotator_ckpts_path


class DoubleConvBlock(torch.nn.Module):
    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def __call__(self, x, down_sampling=False):
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def __call__(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5


class HEDdetector:
    def __init__(self):
        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
        if not os.path.exists(modelpath):
            modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
        self.netNetwork.load_state_dict(torch.load(modelpath))

    def __call__(self, input_image):
        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
        return edge


def nms(x, t, s):
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    # keep a pixel only where it is the local maximum along one of the
    # four principal directions, then threshold to a binary map
    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z
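A usage sketch for the detector and the NMS helper; the file name is hypothetical, and note that this implementation moves the network to CUDA unconditionally:

import cv2

hed = HEDdetector()                         # fetches ControlNetHED.pth if missing
img = cv2.imread("input.png")[:, :, ::-1]   # hypothetical path; BGR -> RGB per the header note
edge = hed(img)                             # uint8 soft edge map in [0, 255]

# thin the soft edges: blur with sigma 3, keep directional maxima, binarize
# at 127 (thresholds commonly used for fake-scribble preprocessing)
scribble = nms(edge, 127, 3.0)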
src/flux/annotator/midas/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/flux/annotator/midas/__init__.py ADDED
@@ -0,0 +1,42 @@
# MiDaS Depth Estimation
# From https://github.com/isl-org/MiDaS
# MIT LICENSE

import cv2
import numpy as np
import torch

from einops import rearrange
from .api import MiDaSInference


class MidasDetector:
    def __init__(self):
        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()

    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
        assert input_image.ndim == 3
        image_depth = input_image
        with torch.no_grad():
            image_depth = torch.from_numpy(image_depth).float().cuda()
            image_depth = image_depth / 127.5 - 1.0
            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
            depth = self.model(image_depth)[0]

            depth_pt = depth.clone()
            depth_pt -= torch.min(depth_pt)
            depth_pt /= torch.max(depth_pt)
            depth_pt = depth_pt.cpu().numpy()
            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)

            depth_np = depth.cpu().numpy()
            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
            z = np.ones_like(x) * a
            x[depth_pt < bg_th] = 0
            y[depth_pt < bg_th] = 0
            normal = np.stack([x, y, z], axis=2)
            normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
            normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

        return depth_image, normal_image
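Usage mirrors the other annotators; a sketch assuming a CUDA device, a hypothetical image path, and an input already resized so both sides suit the ViT backbone (ControlNet-style pipelines snap to a multiple of 64 first):

import cv2

midas = MidasDetector()                      # downloads dpt_hybrid weights on first use
img = cv2.imread("room.jpg")[:, :, ::-1]     # hypothetical path; RGB uint8
depth_map, normal_map = midas(img)           # HxW and HxWx3 uint8 images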
src/flux/annotator/midas/api.py ADDED
@@ -0,0 +1,168 @@
# based on https://github.com/isl-org/MiDaS

import cv2
import os
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from huggingface_hub import hf_hub_download

from .midas.dpt_depth import DPTDepthModel
from .midas.midas_net import MidasNet
from .midas.midas_net_custom import MidasNet_small
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
from ...annotator.util import annotator_ckpts_path


ISL_PATHS = {
    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
    "midas_v21": "",
    "midas_v21_small": "",
}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform


def load_model(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        if not os.path.exists(model_path):
            model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt")

        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        self.model.train = disabled_train

    def forward(self, x):
        with torch.no_grad():
            prediction = self.model(x)
        return prediction
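The transform returned by load_midas_transform is meant to be paired with MiDaSInference: it takes an HWC float image in [0, 1] and produces the CHW array the network consumes. A sketch with a random image (first run downloads the dpt_hybrid checkpoint):

import numpy as np
import torch

transform = load_midas_transform("dpt_hybrid")
image = np.random.rand(480, 640, 3).astype(np.float32)   # toy RGB in [0, 1]

sample = transform({"image": image})
x = torch.from_numpy(sample["image"]).unsqueeze(0)       # (1, 3, H', W')

model = MiDaSInference(model_type="dpt_hybrid")
with torch.no_grad():
    depth = model(x)                                     # (1, H', W')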
src/flux/annotator/midas/midas/__init__.py ADDED
File without changes
src/flux/annotator/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
src/flux/annotator/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
import torch
import torch.nn as nn

from .vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)


def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore"):
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 + ResNet-50 hybrid (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output
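A shape check for the custom fusion block: each call upsamples the running decoder feature by 2x and optionally merges a same-resolution skip input. A sketch with toy tensors on CPU:

import torch
import torch.nn as nn

block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=False)
deep = torch.randn(1, 64, 12, 16)   # running decoder feature
skip = torch.randn(1, 64, 12, 16)   # encoder feature at the same scale

out = block(deep, skip)
print(out.shape)                    # torch.Size([1, 64, 24, 32])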
src/flux/annotator/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock,
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_vit,
)


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # set to True if you want to train from scratch; uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
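A toy forward pass through the depth model; with path=None no MiDaS checkpoint is loaded, so this only checks shapes. It still relies on the ViT backbone from vit.py (truncated from this view) and therefore on timm being installed:

import torch

model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True)
model.eval()

x = torch.randn(1, 3, 384, 384)   # sides must fit the ViT patch grid
with torch.no_grad():
    depth = model(x)              # (1, 384, 384): channel dim squeezed away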
src/flux/annotator/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.

        The encoder backbone is fixed to resnext101_wsl.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
src/flux/annotator/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
"""MidasNet_small: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
                 blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand']:
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last:
            print("self.channels_last = ", self.channels_last)
            x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #     print("FUSED ", previous_name, name)
        #     torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
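fuse_model walks the module tree and fuses adjacent Conv2d+BatchNorm2d(+ReLU) runs in place, the usual preparation step before post-training quantization or export. A sketch (constructing MidasNet_small with path=None pulls the EfficientNet backbone via torch.hub, so it needs network access on first run):

model = MidasNet_small(None, features=64, backbone="efficientnet_lite3",
                       exportable=True, non_negative=True,
                       blocks={'expand': True})
model.eval()       # fusion folds BatchNorm statistics, so use eval mode
fuse_model(model)  # matching Conv+BN(+ReLU) sequences are fused in place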
src/flux/annotator/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
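Taken together, `Resize`, `NormalizeImage`, and `PrepareForNet` form the MiDaS preprocessing pipeline. A minimal usage sketch, assuming torchvision's `Compose` and illustrative ImageNet mean/std values (the repo's own configs may differ):

import numpy as np
from torchvision.transforms import Compose

transform = Compose([
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="lower_bound"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}  # HWC in [0, 1]
out = transform(sample)
print(out["image"].shape)  # (3, 384, 512): CHW, sides multiples of 32, short side >= 384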
src/flux/annotator/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ raise ValueError(
177
+ "wrong operation for readout token; "
178
+ "use_readout must be 'ignore', 'add', or 'project'")
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+ pretrained.model = model
356
+
357
+ if use_vit_only:
358
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
+ else:
361
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
+ get_activation("1")
363
+ )
364
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
+ get_activation("2")
366
+ )
367
+
368
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
+
371
+ pretrained.activations = activations
372
+
373
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
+
375
+ if use_vit_only:
376
+ pretrained.act_postprocess1 = nn.Sequential(
377
+ readout_oper[0],
378
+ Transpose(1, 2),
379
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
+ nn.Conv2d(
381
+ in_channels=vit_features,
382
+ out_channels=features[0],
383
+ kernel_size=1,
384
+ stride=1,
385
+ padding=0,
386
+ ),
387
+ nn.ConvTranspose2d(
388
+ in_channels=features[0],
389
+ out_channels=features[0],
390
+ kernel_size=4,
391
+ stride=4,
392
+ padding=0,
393
+ bias=True,
394
+ dilation=1,
395
+ groups=1,
396
+ ),
397
+ )
398
+
399
+ pretrained.act_postprocess2 = nn.Sequential(
400
+ readout_oper[1],
401
+ Transpose(1, 2),
402
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
+ nn.Conv2d(
404
+ in_channels=vit_features,
405
+ out_channels=features[1],
406
+ kernel_size=1,
407
+ stride=1,
408
+ padding=0,
409
+ ),
410
+ nn.ConvTranspose2d(
411
+ in_channels=features[1],
412
+ out_channels=features[1],
413
+ kernel_size=2,
414
+ stride=2,
415
+ padding=0,
416
+ bias=True,
417
+ dilation=1,
418
+ groups=1,
419
+ ),
420
+ )
421
+ else:
422
+ pretrained.act_postprocess1 = nn.Sequential(
423
+ nn.Identity(), nn.Identity(), nn.Identity()
424
+ )
425
+ pretrained.act_postprocess2 = nn.Sequential(
426
+ nn.Identity(), nn.Identity(), nn.Identity()
427
+ )
428
+
429
+ pretrained.act_postprocess3 = nn.Sequential(
430
+ readout_oper[2],
431
+ Transpose(1, 2),
432
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
+ nn.Conv2d(
434
+ in_channels=vit_features,
435
+ out_channels=features[2],
436
+ kernel_size=1,
437
+ stride=1,
438
+ padding=0,
439
+ ),
440
+ )
441
+
442
+ pretrained.act_postprocess4 = nn.Sequential(
443
+ readout_oper[3],
444
+ Transpose(1, 2),
445
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
+ nn.Conv2d(
447
+ in_channels=vit_features,
448
+ out_channels=features[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=features[3],
455
+ out_channels=features[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ pretrained.model.start_index = start_index
463
+ pretrained.model.patch_size = [16, 16]
464
+
465
+ # We inject this function into the VisionTransformer instances so that
466
+ # we can use it with interpolated position embeddings without modifying the library source.
467
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
+
469
+ # We inject this function into the VisionTransformer instances so that
470
+ # we can use it with interpolated position embeddings without modifying the library source.
471
+ pretrained.model._resize_pos_embed = types.MethodType(
472
+ _resize_pos_embed, pretrained.model
473
+ )
474
+
475
+ return pretrained
476
+
477
+
478
+ def _make_pretrained_vitb_rn50_384(
479
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
+ ):
481
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
+
483
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
484
+ return _make_vit_b_rn50_backbone(
485
+ model,
486
+ features=[256, 512, 768, 768],
487
+ size=[384, 384],
488
+ hooks=hooks,
489
+ use_vit_only=use_vit_only,
490
+ use_readout=use_readout,
491
+ )
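The helpers above wire timm ViTs into DPT-style multi-scale feature extractors. A minimal sketch of building a ViT-B/16 backbone and pulling its four feature maps (assumes a timm version contemporary with this code; `pretrained=False` avoids a weight download):

import torch

backbone = _make_pretrained_vitb16_384(pretrained=False, use_readout="project")
x = torch.randn(1, 3, 384, 384)  # resolution must be divisible by the 16x16 patch size
with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)
print([tuple(t.shape) for t in (layer_1, layer_2, layer_3, layer_4)])
# expected: [(1, 96, 96, 96), (1, 192, 48, 48), (1, 384, 24, 24), (1, 768, 12, 12)]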
src/flux/annotator/midas/utils.py ADDED
@@ -0,0 +1,189 @@
 
1
+ """Utils for monoDepth."""
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def read_pfm(path):
10
+ """Read pfm file.
11
+
12
+ Args:
13
+ path (str): path to file
14
+
15
+ Returns:
16
+ tuple: (data, scale)
17
+ """
18
+ with open(path, "rb") as file:
19
+
20
+ color = None
21
+ width = None
22
+ height = None
23
+ scale = None
24
+ endian = None
25
+
26
+ header = file.readline().rstrip()
27
+ if header.decode("ascii") == "PF":
28
+ color = True
29
+ elif header.decode("ascii") == "Pf":
30
+ color = False
31
+ else:
32
+ raise Exception("Not a PFM file: " + path)
33
+
34
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
+ if dim_match:
36
+ width, height = list(map(int, dim_match.groups()))
37
+ else:
38
+ raise Exception("Malformed PFM header.")
39
+
40
+ scale = float(file.readline().decode("ascii").rstrip())
41
+ if scale < 0:
42
+ # little-endian
43
+ endian = "<"
44
+ scale = -scale
45
+ else:
46
+ # big-endian
47
+ endian = ">"
48
+
49
+ data = np.fromfile(file, endian + "f")
50
+ shape = (height, width, 3) if color else (height, width)
51
+
52
+ data = np.reshape(data, shape)
53
+ data = np.flipud(data)
54
+
55
+ return data, scale
56
+
57
+
58
+ def write_pfm(path, image, scale=1):
59
+ """Write pfm file.
60
+
61
+ Args:
62
+ path (str): path to file
63
+ image (array): data
64
+ scale (int, optional): Scale. Defaults to 1.
65
+ """
66
+
67
+ with open(path, "wb") as file:
68
+ color = None
69
+
70
+ if image.dtype.name != "float32":
71
+ raise Exception("Image dtype must be float32.")
72
+
73
+ image = np.flipud(image)
74
+
75
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
+ color = True
77
+ elif (
78
+ len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1)
79
+ ): # greyscale
80
+ color = False
81
+ else:
82
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
+
84
+ file.write("PF\n" if color else "Pf\n".encode())
85
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
+
87
+ endian = image.dtype.byteorder
88
+
89
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
90
+ scale = -scale
91
+
92
+ file.write("%f\n".encode() % scale)
93
+
94
+ image.tofile(file)
95
+
96
+
97
+ def read_image(path):
98
+ """Read image and output RGB image (0-1).
99
+
100
+ Args:
101
+ path (str): path to file
102
+
103
+ Returns:
104
+ array: RGB image (0-1)
105
+ """
106
+ img = cv2.imread(path)
107
+
108
+ if img.ndim == 2:
109
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
+
111
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
+
113
+ return img
114
+
115
+
116
+ def resize_image(img):
117
+ """Resize image and make it fit for network.
118
+
119
+ Args:
120
+ img (array): image
121
+
122
+ Returns:
123
+ tensor: data ready for network
124
+ """
125
+ height_orig = img.shape[0]
126
+ width_orig = img.shape[1]
127
+
128
+ if width_orig > height_orig:
129
+ scale = width_orig / 384
130
+ else:
131
+ scale = height_orig / 384
132
+
133
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
+
136
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
+
138
+ img_resized = (
139
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
+ )
141
+ img_resized = img_resized.unsqueeze(0)
142
+
143
+ return img_resized
144
+
145
+
146
+ def resize_depth(depth, width, height):
147
+ """Resize depth map and bring to CPU (numpy).
148
+
149
+ Args:
150
+ depth (tensor): depth
151
+ width (int): image width
152
+ height (int): image height
153
+
154
+ Returns:
155
+ array: processed depth
156
+ """
157
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
+
159
+ depth_resized = cv2.resize(
160
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
+ )
162
+
163
+ return depth_resized
164
+
165
+ def write_depth(path, depth, bits=1):
166
+ """Write depth map to pfm and png file.
167
+
168
+ Args:
169
+ path (str): filepath without extension
170
+ depth (array): depth
+ bits (int, optional): bytes per PNG sample (1 -> uint8, 2 -> uint16). Defaults to 1.
171
+ """
172
+ write_pfm(path + ".pfm", depth.astype(np.float32))
173
+
174
+ depth_min = depth.min()
175
+ depth_max = depth.max()
176
+
177
+ max_val = (2**(8*bits))-1
178
+
179
+ if depth_max - depth_min > np.finfo("float").eps:
180
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
+ else:
182
+ out = np.zeros(depth.shape, dtype=depth.dtype)
183
+
184
+ if bits == 1:
185
+ cv2.imwrite(path + ".png", out.astype("uint8"))
186
+ elif bits == 2:
187
+ cv2.imwrite(path + ".png", out.astype("uint16"))
188
+
189
+ return
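A minimal round-trip sketch for the I/O helpers above (the path is illustrative):

import numpy as np

depth = np.random.rand(240, 320).astype(np.float32)
write_depth("/tmp/depth_demo", depth, bits=2)  # writes depth_demo.pfm and a 16-bit depth_demo.png
data, scale = read_pfm("/tmp/depth_demo.pfm")
print(data.shape, scale)  # (240, 320) 1.0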
src/flux/annotator/mlsd/LICENSE ADDED
@@ -0,0 +1,201 @@
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021-present NAVER Corp.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
src/flux/annotator/mlsd/__init__.py ADDED
@@ -0,0 +1,40 @@
 
1
+ # MLSD Line Detection
2
+ # From https://github.com/navervision/mlsd
3
+ # Apache-2.0 license
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import os
9
+
10
+ from einops import rearrange
11
+ from huggingface_hub import hf_hub_download
12
+ from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
13
+ from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
14
+ from .utils import pred_lines
15
+
16
+ from ...annotator.util import annotator_ckpts_path
17
+
18
+
19
+ class MLSDdetector:
20
+ def __init__(self):
21
+ model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
22
+ if not os.path.exists(model_path):
23
+ model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth")
24
+ model = MobileV2_MLSD_Large()
25
+ model.load_state_dict(torch.load(model_path), strict=True)
26
+ self.model = model.cuda().eval()
27
+
28
+ def __call__(self, input_image, thr_v, thr_d):
29
+ assert input_image.ndim == 3
30
+ img = input_image
31
+ img_output = np.zeros_like(img)
32
+ try:
33
+ with torch.no_grad():
34
+ lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
35
+ for line in lines:
36
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
37
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
38
+ except Exception:
39
+ pass
40
+ return img_output[:, :, 0]
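A minimal sketch of running the detector (requires a CUDA device, since the wrapper calls `.cuda()`, and downloads the checkpoint from the Hub on first use; the file name and thresholds are illustrative):

import cv2

detector = MLSDdetector()
img = cv2.imread("input.jpg")             # HWC uint8
line_map = detector(img, thr_v=0.1, thr_d=0.1)
cv2.imwrite("lines.png", line_map)        # single-channel white-on-black line drawing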
src/flux/annotator/mlsd/models/mbv2_mlsd_large.py ADDED
@@ -0,0 +1,292 @@
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class BlockTypeA(nn.Module):
10
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
11
+ super(BlockTypeA, self).__init__()
12
+ self.conv1 = nn.Sequential(
13
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
14
+ nn.BatchNorm2d(out_c2),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+ self.conv2 = nn.Sequential(
18
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
19
+ nn.BatchNorm2d(out_c1),
20
+ nn.ReLU(inplace=True)
21
+ )
22
+ self.upscale = upscale
23
+
24
+ def forward(self, a, b):
25
+ b = self.conv1(b)
26
+ a = self.conv2(a)
27
+ if self.upscale:
28
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
29
+ return torch.cat((a, b), dim=1)
30
+
31
+
32
+ class BlockTypeB(nn.Module):
33
+ def __init__(self, in_c, out_c):
34
+ super(BlockTypeB, self).__init__()
35
+ self.conv1 = nn.Sequential(
36
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
37
+ nn.BatchNorm2d(in_c),
38
+ nn.ReLU()
39
+ )
40
+ self.conv2 = nn.Sequential(
41
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
42
+ nn.BatchNorm2d(out_c),
43
+ nn.ReLU()
44
+ )
45
+
46
+ def forward(self, x):
47
+ x = self.conv1(x) + x
48
+ x = self.conv2(x)
49
+ return x
50
+
51
+ class BlockTypeC(nn.Module):
52
+ def __init__(self, in_c, out_c):
53
+ super(BlockTypeC, self).__init__()
54
+ self.conv1 = nn.Sequential(
55
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
56
+ nn.BatchNorm2d(in_c),
57
+ nn.ReLU()
58
+ )
59
+ self.conv2 = nn.Sequential(
60
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
61
+ nn.BatchNorm2d(in_c),
62
+ nn.ReLU()
63
+ )
64
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
65
+
66
+ def forward(self, x):
67
+ x = self.conv1(x)
68
+ x = self.conv2(x)
69
+ x = self.conv3(x)
70
+ return x
71
+
72
+ def _make_divisible(v, divisor, min_value=None):
73
+ """
74
+ This function is taken from the original tf repo.
75
+ It ensures that all layers have a channel number that is divisible by 8
76
+ It can be seen here:
77
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
78
+ :param v:
79
+ :param divisor:
80
+ :param min_value:
81
+ :return:
82
+ """
83
+ if min_value is None:
84
+ min_value = divisor
85
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
86
+ # Make sure that round down does not go down by more than 10%.
87
+ if new_v < 0.9 * v:
88
+ new_v += divisor
89
+ return new_v
90
+
91
+
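For intuition, `_make_divisible` rounds a channel count to the nearest multiple of `divisor` and bumps up one step whenever rounding would shrink the count by more than 10%:

# Worked examples:
# _make_divisible(30, 8)  -> 32  (nearest multiple of 8)
# _make_divisible(20, 8)  -> 24
# _make_divisible(36, 16) -> 48  (32 would shrink the count by more than 10%)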
92
+ class ConvBNReLU(nn.Sequential):
93
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
94
+ self.channel_pad = out_planes - in_planes
95
+ self.stride = stride
96
+ #padding = (kernel_size - 1) // 2
97
+
98
+ # TFLite uses slightly different padding than PyTorch
99
+ if stride == 2:
100
+ padding = 0
101
+ else:
102
+ padding = (kernel_size - 1) // 2
103
+
104
+ super(ConvBNReLU, self).__init__(
105
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
106
+ nn.BatchNorm2d(out_planes),
107
+ nn.ReLU6(inplace=True)
108
+ )
109
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
110
+
111
+
112
+ def forward(self, x):
113
+ # TFLite uses different padding
114
+ if self.stride == 2:
115
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
116
+ #print(x.shape)
117
+
118
+ for module in self:
119
+ if not isinstance(module, nn.MaxPool2d):
120
+ x = module(x)
121
+ return x
122
+
123
+
124
+ class InvertedResidual(nn.Module):
125
+ def __init__(self, inp, oup, stride, expand_ratio):
126
+ super(InvertedResidual, self).__init__()
127
+ self.stride = stride
128
+ assert stride in [1, 2]
129
+
130
+ hidden_dim = int(round(inp * expand_ratio))
131
+ self.use_res_connect = self.stride == 1 and inp == oup
132
+
133
+ layers = []
134
+ if expand_ratio != 1:
135
+ # pw
136
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
137
+ layers.extend([
138
+ # dw
139
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
140
+ # pw-linear
141
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
142
+ nn.BatchNorm2d(oup),
143
+ ])
144
+ self.conv = nn.Sequential(*layers)
145
+
146
+ def forward(self, x):
147
+ if self.use_res_connect:
148
+ return x + self.conv(x)
149
+ else:
150
+ return self.conv(x)
151
+
152
+
153
+ class MobileNetV2(nn.Module):
154
+ def __init__(self, pretrained=True):
155
+ """
156
+ Truncated MobileNet V2 backbone used as the MLSD feature extractor.
157
+ Args:
158
+ pretrained (bool): If True, load the matching layers from
159
+ torchvision's ImageNet mobilenet_v2 checkpoint.
160
+ Note:
161
+ width_mult, round_nearest, and the inverted residual layout are
162
+ fixed internally; only the first five stages are kept, exposing
163
+ FPN features c1-c5.
164
+ """
165
+ super(MobileNetV2, self).__init__()
166
+
167
+ block = InvertedResidual
168
+ input_channel = 32
169
+ last_channel = 1280
170
+ width_mult = 1.0
171
+ round_nearest = 8
172
+
173
+ inverted_residual_setting = [
174
+ # t, c, n, s
175
+ [1, 16, 1, 1],
176
+ [6, 24, 2, 2],
177
+ [6, 32, 3, 2],
178
+ [6, 64, 4, 2],
179
+ [6, 96, 3, 1],
180
+ #[6, 160, 3, 2],
181
+ #[6, 320, 1, 1],
182
+ ]
183
+
184
+ # only check the first element, assuming user knows t,c,n,s are required
185
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
186
+ raise ValueError("inverted_residual_setting should be non-empty "
187
+ "or a 4-element list, got {}".format(inverted_residual_setting))
188
+
189
+ # building first layer
190
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
191
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
192
+ features = [ConvBNReLU(4, input_channel, stride=2)]
193
+ # building inverted residual blocks
194
+ for t, c, n, s in inverted_residual_setting:
195
+ output_channel = _make_divisible(c * width_mult, round_nearest)
196
+ for i in range(n):
197
+ stride = s if i == 0 else 1
198
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
199
+ input_channel = output_channel
200
+
201
+ self.features = nn.Sequential(*features)
202
+ self.fpn_selected = [1, 3, 6, 10, 13]
203
+ # weight initialization
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
207
+ if m.bias is not None:
208
+ nn.init.zeros_(m.bias)
209
+ elif isinstance(m, nn.BatchNorm2d):
210
+ nn.init.ones_(m.weight)
211
+ nn.init.zeros_(m.bias)
212
+ elif isinstance(m, nn.Linear):
213
+ nn.init.normal_(m.weight, 0, 0.01)
214
+ nn.init.zeros_(m.bias)
215
+ if pretrained:
216
+ self._load_pretrained_model()
217
+
218
+ def _forward_impl(self, x):
219
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
220
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
221
+ fpn_features = []
222
+ for i, f in enumerate(self.features):
223
+ if i > self.fpn_selected[-1]:
224
+ break
225
+ x = f(x)
226
+ if i in self.fpn_selected:
227
+ fpn_features.append(x)
228
+
229
+ c1, c2, c3, c4, c5 = fpn_features
230
+ return c1, c2, c3, c4, c5
231
+
232
+
233
+ def forward(self, x):
234
+ return self._forward_impl(x)
235
+
236
+ def _load_pretrained_model(self):
237
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
238
+ model_dict = {}
239
+ state_dict = self.state_dict()
240
+ for k, v in pretrain_dict.items():
241
+ if k in state_dict:
242
+ model_dict[k] = v
243
+ state_dict.update(model_dict)
244
+ self.load_state_dict(state_dict)
245
+
246
+
247
+ class MobileV2_MLSD_Large(nn.Module):
248
+ def __init__(self):
249
+ super(MobileV2_MLSD_Large, self).__init__()
250
+
251
+ self.backbone = MobileNetV2(pretrained=False)
252
+ ## A, B
253
+ self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
254
+ out_c1= 64, out_c2=64,
255
+ upscale=False)
256
+ self.block16 = BlockTypeB(128, 64)
257
+
258
+ ## A, B
259
+ self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
260
+ out_c1= 64, out_c2= 64)
261
+ self.block18 = BlockTypeB(128, 64)
262
+
263
+ ## A, B
264
+ self.block19 = BlockTypeA(in_c1=24, in_c2=64,
265
+ out_c1=64, out_c2=64)
266
+ self.block20 = BlockTypeB(128, 64)
267
+
268
+ ## A, B, C
269
+ self.block21 = BlockTypeA(in_c1=16, in_c2=64,
270
+ out_c1=64, out_c2=64)
271
+ self.block22 = BlockTypeB(128, 64)
272
+
273
+ self.block23 = BlockTypeC(64, 16)
274
+
275
+ def forward(self, x):
276
+ c1, c2, c3, c4, c5 = self.backbone(x)
277
+
278
+ x = self.block15(c4, c5)
279
+ x = self.block16(x)
280
+
281
+ x = self.block17(c3, x)
282
+ x = self.block18(x)
283
+
284
+ x = self.block19(c2, x)
285
+ x = self.block20(x)
286
+
287
+ x = self.block21(c1, x)
288
+ x = self.block22(x)
289
+ x = self.block23(x)
290
+ x = x[:, 7:, :, :]
291
+
292
+ return x
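A minimal shape check for the large head. The network expects a 4-channel input (RGB plus an all-ones channel, as prepared in mlsd/utils.py) and, after the `x[:, 7:, :, :]` slice, emits 9 channels at half the input resolution:

import torch

net = MobileV2_MLSD_Large().eval()
with torch.no_grad():
    y = net(torch.randn(1, 4, 512, 512))
print(y.shape)  # torch.Size([1, 9, 256, 256])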
src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py ADDED
@@ -0,0 +1,275 @@
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class BlockTypeA(nn.Module):
10
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
11
+ super(BlockTypeA, self).__init__()
12
+ self.conv1 = nn.Sequential(
13
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
14
+ nn.BatchNorm2d(out_c2),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+ self.conv2 = nn.Sequential(
18
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
19
+ nn.BatchNorm2d(out_c1),
20
+ nn.ReLU(inplace=True)
21
+ )
22
+ self.upscale = upscale
23
+
24
+ def forward(self, a, b):
25
+ b = self.conv1(b)
26
+ a = self.conv2(a)
27
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
28
+ return torch.cat((a, b), dim=1)
29
+
30
+
31
+ class BlockTypeB(nn.Module):
32
+ def __init__(self, in_c, out_c):
33
+ super(BlockTypeB, self).__init__()
34
+ self.conv1 = nn.Sequential(
35
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
36
+ nn.BatchNorm2d(in_c),
37
+ nn.ReLU()
38
+ )
39
+ self.conv2 = nn.Sequential(
40
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
41
+ nn.BatchNorm2d(out_c),
42
+ nn.ReLU()
43
+ )
44
+
45
+ def forward(self, x):
46
+ x = self.conv1(x) + x
47
+ x = self.conv2(x)
48
+ return x
49
+
50
+ class BlockTypeC(nn.Module):
51
+ def __init__(self, in_c, out_c):
52
+ super(BlockTypeC, self).__init__()
53
+ self.conv1 = nn.Sequential(
54
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
55
+ nn.BatchNorm2d(in_c),
56
+ nn.ReLU()
57
+ )
58
+ self.conv2 = nn.Sequential(
59
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
60
+ nn.BatchNorm2d(in_c),
61
+ nn.ReLU()
62
+ )
63
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
64
+
65
+ def forward(self, x):
66
+ x = self.conv1(x)
67
+ x = self.conv2(x)
68
+ x = self.conv3(x)
69
+ return x
70
+
71
+ def _make_divisible(v, divisor, min_value=None):
72
+ """
73
+ This function is taken from the original tf repo.
74
+ It ensures that all layers have a channel number that is divisible by 8
75
+ It can be seen here:
76
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
77
+ :param v:
78
+ :param divisor:
79
+ :param min_value:
80
+ :return:
81
+ """
82
+ if min_value is None:
83
+ min_value = divisor
84
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
85
+ # Make sure that round down does not go down by more than 10%.
86
+ if new_v < 0.9 * v:
87
+ new_v += divisor
88
+ return new_v
89
+
90
+
91
+ class ConvBNReLU(nn.Sequential):
92
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
93
+ self.channel_pad = out_planes - in_planes
94
+ self.stride = stride
95
+ #padding = (kernel_size - 1) // 2
96
+
97
+ # TFLite uses slightly different padding than PyTorch
98
+ if stride == 2:
99
+ padding = 0
100
+ else:
101
+ padding = (kernel_size - 1) // 2
102
+
103
+ super(ConvBNReLU, self).__init__(
104
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
105
+ nn.BatchNorm2d(out_planes),
106
+ nn.ReLU6(inplace=True)
107
+ )
108
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
109
+
110
+
111
+ def forward(self, x):
112
+ # TFLite uses different padding
113
+ if self.stride == 2:
114
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
115
+ #print(x.shape)
116
+
117
+ for module in self:
118
+ if not isinstance(module, nn.MaxPool2d):
119
+ x = module(x)
120
+ return x
121
+
122
+
123
+ class InvertedResidual(nn.Module):
124
+ def __init__(self, inp, oup, stride, expand_ratio):
125
+ super(InvertedResidual, self).__init__()
126
+ self.stride = stride
127
+ assert stride in [1, 2]
128
+
129
+ hidden_dim = int(round(inp * expand_ratio))
130
+ self.use_res_connect = self.stride == 1 and inp == oup
131
+
132
+ layers = []
133
+ if expand_ratio != 1:
134
+ # pw
135
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
136
+ layers.extend([
137
+ # dw
138
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
139
+ # pw-linear
140
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
141
+ nn.BatchNorm2d(oup),
142
+ ])
143
+ self.conv = nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ if self.use_res_connect:
147
+ return x + self.conv(x)
148
+ else:
149
+ return self.conv(x)
150
+
151
+
152
+ class MobileNetV2(nn.Module):
153
+ def __init__(self, pretrained=True):
154
+ """
155
+ Truncated MobileNet V2 backbone used by the tiny MLSD model.
156
+ Args:
157
+ pretrained (bool): Accepted for symmetry with the large model; the
158
+ checkpoint load below is currently commented out, so no weights
159
+ are downloaded.
160
+ Note:
161
+ width_mult, round_nearest, and the layer layout are fixed internally;
162
+ only the first four stages are kept, exposing FPN features c2-c4.
163
+ """
164
+ super(MobileNetV2, self).__init__()
165
+
166
+ block = InvertedResidual
167
+ input_channel = 32
168
+ last_channel = 1280
169
+ width_mult = 1.0
170
+ round_nearest = 8
171
+
172
+ inverted_residual_setting = [
173
+ # t, c, n, s
174
+ [1, 16, 1, 1],
175
+ [6, 24, 2, 2],
176
+ [6, 32, 3, 2],
177
+ [6, 64, 4, 2],
178
+ #[6, 96, 3, 1],
179
+ #[6, 160, 3, 2],
180
+ #[6, 320, 1, 1],
181
+ ]
182
+
183
+ # only check the first element, assuming user knows t,c,n,s are required
184
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
185
+ raise ValueError("inverted_residual_setting should be non-empty "
186
+ "or a 4-element list, got {}".format(inverted_residual_setting))
187
+
188
+ # building first layer
189
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
190
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
191
+ features = [ConvBNReLU(4, input_channel, stride=2)]
192
+ # building inverted residual blocks
193
+ for t, c, n, s in inverted_residual_setting:
194
+ output_channel = _make_divisible(c * width_mult, round_nearest)
195
+ for i in range(n):
196
+ stride = s if i == 0 else 1
197
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
198
+ input_channel = output_channel
199
+ self.features = nn.Sequential(*features)
200
+
201
+ self.fpn_selected = [3, 6, 10]
202
+ # weight initialization
203
+ for m in self.modules():
204
+ if isinstance(m, nn.Conv2d):
205
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
206
+ if m.bias is not None:
207
+ nn.init.zeros_(m.bias)
208
+ elif isinstance(m, nn.BatchNorm2d):
209
+ nn.init.ones_(m.weight)
210
+ nn.init.zeros_(m.bias)
211
+ elif isinstance(m, nn.Linear):
212
+ nn.init.normal_(m.weight, 0, 0.01)
213
+ nn.init.zeros_(m.bias)
214
+
215
+ #if pretrained:
216
+ # self._load_pretrained_model()
217
+
218
+ def _forward_impl(self, x):
219
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
220
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
221
+ fpn_features = []
222
+ for i, f in enumerate(self.features):
223
+ if i > self.fpn_selected[-1]:
224
+ break
225
+ x = f(x)
226
+ if i in self.fpn_selected:
227
+ fpn_features.append(x)
228
+
229
+ c2, c3, c4 = fpn_features
230
+ return c2, c3, c4
231
+
232
+
233
+ def forward(self, x):
234
+ return self._forward_impl(x)
235
+
236
+ def _load_pretrained_model(self):
237
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
238
+ model_dict = {}
239
+ state_dict = self.state_dict()
240
+ for k, v in pretrain_dict.items():
241
+ if k in state_dict:
242
+ model_dict[k] = v
243
+ state_dict.update(model_dict)
244
+ self.load_state_dict(state_dict)
245
+
246
+
247
+ class MobileV2_MLSD_Tiny(nn.Module):
248
+ def __init__(self):
249
+ super(MobileV2_MLSD_Tiny, self).__init__()
250
+
251
+ self.backbone = MobileNetV2(pretrained=True)
252
+
253
+ self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
254
+ out_c1= 64, out_c2=64)
255
+ self.block13 = BlockTypeB(128, 64)
256
+
257
+ self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
258
+ out_c1= 32, out_c2= 32)
259
+ self.block15 = BlockTypeB(64, 64)
260
+
261
+ self.block16 = BlockTypeC(64, 16)
262
+
263
+ def forward(self, x):
264
+ c2, c3, c4 = self.backbone(x)
265
+
266
+ x = self.block12(c3, c4)
267
+ x = self.block13(x)
268
+ x = self.block14(c2, x)
269
+ x = self.block15(x)
270
+ x = self.block16(x)
271
+ x = x[:, 7:, :, :]
272
+ #print(x.shape)
273
+ x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
274
+
275
+ return x
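The tiny variant follows the same pattern but upsamples its final map, so the output shape matches the large model for the same input:

import torch

net = MobileV2_MLSD_Tiny().eval()
with torch.no_grad():
    y = net(torch.randn(1, 4, 512, 512))
print(y.shape)  # torch.Size([1, 9, 256, 256])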
src/flux/annotator/mlsd/utils.py ADDED
@@ -0,0 +1,580 @@
 
+ '''
+ modified by lihaoweicv
+ pytorch version
+ '''
+
+ '''
+ M-LSD
+ Copyright 2021-present NAVER Corp.
+ Apache License v2.0
+ '''
+
+ import os
+ import numpy as np
+ import cv2
+ import torch
+ from torch.nn import functional as F
+
+
+ def deccode_output_score_and_ptss(tpMap, topk_n=200, ksize=5):
+     '''
+     tpMap:
+         center: tpMap[:, 0, :, :]
+         displacement: tpMap[:, 1:5, :, :]
+     '''
+     b, c, h, w = tpMap.shape
+     assert b == 1, 'only support bsize==1'
+     displacement = tpMap[:, 1:5, :, :][0]
+     center = tpMap[:, 0, :, :]
+     heat = torch.sigmoid(center)
+     hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize - 1) // 2)
+     keep = (hmax == heat).float()
+     heat = heat * keep
+     heat = heat.reshape(-1, )
+
+     scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
+     yy = torch.floor_divide(indices, w).unsqueeze(-1)
+     xx = torch.fmod(indices, w).unsqueeze(-1)
+     ptss = torch.cat((yy, xx), dim=-1)
+
+     ptss = ptss.detach().cpu().numpy()
+     scores = scores.detach().cpu().numpy()
+     displacement = displacement.detach().cpu().numpy()
+     displacement = displacement.transpose((1, 2, 0))
+     return ptss, scores, displacement
+
+
+ def pred_lines(image, model,
+                input_shape=[512, 512],
+                score_thr=0.10,
+                dist_thr=20.0):
+     h, w, _ = image.shape
+     h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
+
+     resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+                                     np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+
+     resized_image = resized_image.transpose((2, 0, 1))
+     batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+     batch_image = (batch_image / 127.5) - 1.0
+
+     batch_image = torch.from_numpy(batch_image).float().cuda()
+     outputs = model(batch_image)
+     pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+     start = vmap[:, :, :2]
+     end = vmap[:, :, 2:]
+     dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+     segments_list = []
+     for center, score in zip(pts, pts_score):
+         y, x = center
+         distance = dist_map[y, x]
+         if score > score_thr and distance > dist_thr:
+             disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+             x_start = x + disp_x_start
+             y_start = y + disp_y_start
+             x_end = x + disp_x_end
+             y_end = y + disp_y_end
+             segments_list.append([x_start, y_start, x_end, y_end])
+
+     lines = 2 * np.array(segments_list)  # 256 > 512
+     lines[:, 0] = lines[:, 0] * w_ratio
+     lines[:, 1] = lines[:, 1] * h_ratio
+     lines[:, 2] = lines[:, 2] * w_ratio
+     lines[:, 3] = lines[:, 3] * h_ratio
+
+     return lines
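A minimal sketch of driving `pred_lines` (not part of the commit; the checkpoint name, image path, and import path are assumptions, and a CUDA device is required since the input tensor is moved with `.cuda()`):

import cv2
import torch
from src.flux.annotator.mlsd.models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny  # import path assumed

model = MobileV2_MLSD_Tiny().cuda().eval()
# model.load_state_dict(torch.load('mlsd_tiny_512_fp32.pth'), strict=True)  # assumed checkpoint name

img = cv2.imread('room.jpg')  # any HxWx3 uint8 image; pred_lines resizes internally
with torch.no_grad():
    lines = pred_lines(img, model, [512, 512], score_thr=0.10, dist_thr=20.0)
for x1, y1, x2, y2 in lines:  # one row per segment, already mapped back to image coordinates
    cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)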
+
+
+ def pred_squares(image,
+                  model,
+                  input_shape=[512, 512],
+                  params={'score': 0.06,
+                          'outside_ratio': 0.28,
+                          'inside_ratio': 0.45,
+                          'w_overlap': 0.0,
+                          'w_degree': 1.95,
+                          'w_length': 0.0,
+                          'w_area': 1.86,
+                          'w_center': 0.14}):
+     '''
+     shape = [height, width]
+     '''
+     h, w, _ = image.shape
+     original_shape = [h, w]
+
+     resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+                                     np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+     resized_image = resized_image.transpose((2, 0, 1))
+     batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+     batch_image = (batch_image / 127.5) - 1.0
+
+     batch_image = torch.from_numpy(batch_image).float().cuda()
+     outputs = model(batch_image)
+
+     pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+     start = vmap[:, :, :2]  # (x, y)
+     end = vmap[:, :, 2:]  # (x, y)
+     dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+     junc_list = []
+     segments_list = []
+     for junc, score in zip(pts, pts_score):
+         y, x = junc
+         distance = dist_map[y, x]
+         if score > params['score'] and distance > 20.0:
+             junc_list.append([x, y])
+             disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+             d_arrow = 1.0
+             x_start = x + d_arrow * disp_x_start
+             y_start = y + d_arrow * disp_y_start
+             x_end = x + d_arrow * disp_x_end
+             y_end = y + d_arrow * disp_y_end
+             segments_list.append([x_start, y_start, x_end, y_end])
+
+     segments = np.array(segments_list)
+
+     ####### post processing for squares
+     # 1. get unique lines
+     point = np.array([[0, 0]])
+     point = point[0]
+     start = segments[:, :2]
+     end = segments[:, 2:]
+     diff = start - end
+     a = diff[:, 1]
+     b = -diff[:, 0]
+     c = a * start[:, 0] + b * start[:, 1]
+
+     d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+     theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+     theta[theta < 0.0] += 180
+     hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+
+     d_quant = 1
+     theta_quant = 2
+     hough[:, 0] //= d_quant
+     hough[:, 1] //= theta_quant
+     _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+
+     acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+     idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+     yx_indices = hough[indices, :].astype('int32')
+     acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+     idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+
+     acc_map_np = acc_map
+     # acc_map = acc_map[None, :, :, None]
+     #
+     # ### fast suppression using tensorflow op
+     # acc_map = tf.constant(acc_map, dtype=tf.float32)
+     # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+     # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+     # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+     # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+     # _, h, w, _ = acc_map.shape
+     # y = tf.expand_dims(topk_indices // w, axis=-1)
+     # x = tf.expand_dims(topk_indices % w, axis=-1)
+     # yx = tf.concat([y, x], axis=-1)
+
+     ### fast suppression using pytorch op
+     acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+     _, _, h, w = acc_map.shape
+     max_acc_map = F.max_pool2d(acc_map, kernel_size=5, stride=1, padding=2)
+     acc_map = acc_map * ((acc_map == max_acc_map).float())
+     flatten_acc_map = acc_map.reshape([-1, ])
+
+     scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+     yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+     xx = torch.fmod(indices, w).unsqueeze(-1)
+     yx = torch.cat((yy, xx), dim=-1)
+
+     yx = yx.detach().cpu().numpy()
+
+     topk_values = scores.detach().cpu().numpy()
+     indices = idx_map[yx[:, 0], yx[:, 1]]
+     basis = 5 // 2
+
+     merged_segments = []
+     for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+         y, x = yx_pt
+         if max_indice == -1 or value == 0:
+             continue
+         segment_list = []
+         for y_offset in range(-basis, basis + 1):
+             for x_offset in range(-basis, basis + 1):
+                 indice = idx_map[y + y_offset, x + x_offset]
+                 cnt = int(acc_map_np[y + y_offset, x + x_offset])
+                 if indice != -1:
+                     segment_list.append(segments[indice])
+                 if cnt > 1:
+                     check_cnt = 1
+                     current_hough = hough[indice]
+                     for new_indice, new_hough in enumerate(hough):
+                         if (current_hough == new_hough).all() and indice != new_indice:
+                             segment_list.append(segments[new_indice])
+                             check_cnt += 1
+                         if check_cnt == cnt:
+                             break
+         group_segments = np.array(segment_list).reshape([-1, 2])
+         sorted_group_segments = np.sort(group_segments, axis=0)
+         x_min, y_min = sorted_group_segments[0, :]
+         x_max, y_max = sorted_group_segments[-1, :]
+
+         deg = theta[max_indice]
+         if deg >= 90:
+             merged_segments.append([x_min, y_max, x_max, y_min])
+         else:
+             merged_segments.append([x_min, y_min, x_max, y_max])
+
+     # 2. get intersections
+     new_segments = np.array(merged_segments)  # (x1, y1, x2, y2)
+     start = new_segments[:, :2]  # (x1, y1)
+     end = new_segments[:, 2:]  # (x2, y2)
+     new_centers = (start + end) / 2.0
+     diff = start - end
+     dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+
+     # ax + by = c
+     a = diff[:, 1]
+     b = -diff[:, 0]
+     c = a * start[:, 0] + b * start[:, 1]
+     pre_det = a[:, None] * b[None, :]
+     det = pre_det - np.transpose(pre_det)
+
+     pre_inter_y = a[:, None] * c[None, :]
+     inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+     pre_inter_x = c[:, None] * b[None, :]
+     inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+     inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+
+     # 3. get corner information
+     # 3.1 get distance
+     '''
+     dist_segments:
+         | dist(0), dist(1), dist(2), ...|
+     dist_inter_to_segment1:
+         | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+         | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+         ...
+     dist_inter_to_segment2:
+         | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+         | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+         ...
+     '''
+
+     dist_inter_to_segment1_start = np.sqrt(
+         np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+     dist_inter_to_segment1_end = np.sqrt(
+         np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+     dist_inter_to_segment2_start = np.sqrt(
+         np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+     dist_inter_to_segment2_end = np.sqrt(
+         np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+
+     # sort ascending
+     dist_inter_to_segment1 = np.sort(
+         np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+         axis=-1)  # [n_batch, n_batch, 2]
+     dist_inter_to_segment2 = np.sort(
+         np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+         axis=-1)  # [n_batch, n_batch, 2]
+
+     # 3.2 get degree
+     inter_to_start = new_centers[:, None, :] - inter_pts
+     deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+     deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+     inter_to_end = new_centers[None, :, :] - inter_pts
+     deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+     deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+
+     '''
+     B -- G
+     |    |
+     C -- R
+     B : blue / G: green / C: cyan / R: red
+
+     0 -- 1
+     |    |
+     3 -- 2
+     '''
+     # rename variables
+     deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+     # sort deg ascending
+     deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+
+     deg_diff_map = np.abs(deg1_map - deg2_map)
+     # we only consider the smallest degree of intersect
+     deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+
+     # define available degree range
+     deg_range = [60, 120]
+
+     corner_dict = {corner_info: [] for corner_info in range(4)}
+     inter_points = []
+     for i in range(inter_pts.shape[0]):
+         for j in range(i + 1, inter_pts.shape[1]):
+             # i, j > line index, always i < j
+             x, y = inter_pts[i, j, :]
+             deg1, deg2 = deg_sort[i, j, :]
+             deg_diff = deg_diff_map[i, j]
+
+             check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+
+             outside_ratio = params['outside_ratio']  # over ratio >>> drop it!
+             inside_ratio = params['inside_ratio']  # over ratio >>> drop it!
+             check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+                                dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+                               (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+                                dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+                              ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+                                dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+                               (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+                                dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+
+             if check_degree and check_distance:
+                 corner_info = None
+
+                 if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+                         (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+                     corner_info, color_info = 0, 'blue'
+                 elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+                     corner_info, color_info = 1, 'green'
+                 elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+                     corner_info, color_info = 2, 'black'
+                 elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+                         (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+                     corner_info, color_info = 3, 'cyan'
+                 else:
+                     corner_info, color_info = 4, 'red'  # we don't use it
+                     continue
+
+                 corner_dict[corner_info].append([x, y, i, j])
+                 inter_points.append([x, y])
+
+     square_list = []
+     connect_list = []
+     segments_list = []
+     for corner0 in corner_dict[0]:
+         for corner1 in corner_dict[1]:
+             connect01 = False
+             for corner0_line in corner0[2:]:
+                 if corner0_line in corner1[2:]:
+                     connect01 = True
+                     break
+             if connect01:
+                 for corner2 in corner_dict[2]:
+                     connect12 = False
+                     for corner1_line in corner1[2:]:
+                         if corner1_line in corner2[2:]:
+                             connect12 = True
+                             break
+                     if connect12:
+                         for corner3 in corner_dict[3]:
+                             connect23 = False
+                             for corner2_line in corner2[2:]:
+                                 if corner2_line in corner3[2:]:
+                                     connect23 = True
+                                     break
+                             if connect23:
+                                 for corner3_line in corner3[2:]:
+                                     if corner3_line in corner0[2:]:
+                                         # SQUARE!!!
+                                         '''
+                                         0 -- 1
+                                         |    |
+                                         3 -- 2
+                                         square_list:
+                                             order: 0 > 1 > 2 > 3
+                                             | x0, y0, x1, y1, x2, y2, x3, y3 |
+                                             | x0, y0, x1, y1, x2, y2, x3, y3 |
+                                             ...
+                                         connect_list:
+                                             order: 01 > 12 > 23 > 30
+                                             | line_idx01, line_idx12, line_idx23, line_idx30 |
+                                             | line_idx01, line_idx12, line_idx23, line_idx30 |
+                                             ...
+                                         segments_list:
+                                             order: 0 > 1 > 2 > 3
+                                             | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+                                             | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+                                             ...
+                                         '''
+                                         square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+                                         connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+                                         segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+
+     def check_outside_inside(segments_info, connect_idx):
+         # return 'outside or inside', min distance, cover_param, peri_param
+         if connect_idx == segments_info[0]:
+             check_dist_mat = dist_inter_to_segment1
+         else:
+             check_dist_mat = dist_inter_to_segment2
+
+         i, j = segments_info
+         min_dist, max_dist = check_dist_mat[i, j, :]
+         connect_dist = dist_segments[connect_idx]
+         if max_dist > connect_dist:
+             return 'outside', min_dist, 0, 1
+         else:
+             return 'inside', min_dist, -1, -1
+
+     top_square = None
+
+     try:
+         map_size = input_shape[0] / 2
+         squares = np.array(square_list).reshape([-1, 4, 2])
+         score_array = []
+         connect_array = np.array(connect_list)
+         segments_array = np.array(segments_list).reshape([-1, 4, 2])
+
+         # get degree of corners:
+         squares_rollup = np.roll(squares, 1, axis=1)
+         squares_rolldown = np.roll(squares, -1, axis=1)
+         vec1 = squares_rollup - squares
+         normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+         vec2 = squares_rolldown - squares
+         normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+         inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1)  # [n_squares, 4]
+         squares_degree = np.arccos(inner_products) * 180 / np.pi  # [n_squares, 4]
+
+         # get square score
+         overlap_scores = []
+         degree_scores = []
+         length_scores = []
+
+         for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+             '''
+             0 -- 1
+             |    |
+             3 -- 2
+
+             # segments: [4, 2]
+             # connects: [4]
+             '''
+
+             ###################################### OVERLAP SCORES
+             cover = 0
+             perimeter = 0
+             # check 0 > 1 > 2 > 3
+             square_length = []
+
+             for start_idx in range(4):
+                 end_idx = (start_idx + 1) % 4
+
+                 connect_idx = connects[start_idx]  # segment idx of segment01
+                 start_segments = segments[start_idx]
+                 end_segments = segments[end_idx]
+
+                 start_point = square[start_idx]
+                 end_point = square[end_idx]
+
+                 # check whether outside or inside
+                 start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+                                                                                                       connect_idx)
+                 end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+
+                 cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+                 perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+
+                 square_length.append(
+                     dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+
+             overlap_scores.append(cover / perimeter)
+             ######################################
+             ###################################### DEGREE SCORES
+             '''
+             deg0 vs deg2
+             deg1 vs deg3
+             '''
+             deg0, deg1, deg2, deg3 = degree
+             deg_ratio1 = deg0 / deg2
+             if deg_ratio1 > 1.0:
+                 deg_ratio1 = 1 / deg_ratio1
+             deg_ratio2 = deg1 / deg3
+             if deg_ratio2 > 1.0:
+                 deg_ratio2 = 1 / deg_ratio2
+             degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+             ######################################
+             ###################################### LENGTH SCORES
+             '''
+             len0 vs len2
+             len1 vs len3
+             '''
+             len0, len1, len2, len3 = square_length
+             len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+             len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+             length_scores.append((len_ratio1 + len_ratio2) / 2)
+
+             ######################################
+
+         overlap_scores = np.array(overlap_scores)
+         overlap_scores /= np.max(overlap_scores)
+
+         degree_scores = np.array(degree_scores)
+         # degree_scores /= np.max(degree_scores)
+
+         length_scores = np.array(length_scores)
+
+         ###################################### AREA SCORES
+         area_scores = np.reshape(squares, [-1, 4, 2])
+         area_x = area_scores[:, :, 0]
+         area_y = area_scores[:, :, 1]
+         correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+         area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+         area_scores = 0.5 * np.abs(area_scores + correction)
+         area_scores /= (map_size * map_size)  # np.max(area_scores)
+         ######################################
+
+         ###################################### CENTER SCORES
+         centers = np.array([[256 // 2, 256 // 2]], dtype='float32')  # [1, 2]
+         # squares: [n, 4, 2]
+         square_centers = np.mean(squares, axis=1)  # [n, 2]
+         center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+         center_scores = center2center / (map_size / np.sqrt(2.0))
+
+         '''
+         score_w = [overlap, degree, area, center, length]
+         '''
+         score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+         score_array = params['w_overlap'] * overlap_scores \
+                       + params['w_degree'] * degree_scores \
+                       + params['w_area'] * area_scores \
+                       - params['w_center'] * center_scores \
+                       + params['w_length'] * length_scores
+
+         best_square = []
+
+         sorted_idx = np.argsort(score_array)[::-1]
+         score_array = score_array[sorted_idx]
+         squares = squares[sorted_idx]
+
+     except Exception as e:
+         pass
+
+     '''return list
+     merged_lines, squares, scores
+     '''
+
+     try:
+         new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+         new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+         new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+         new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+     except:
+         new_segments = []
+
+     try:
+         squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+         squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+     except:
+         squares = []
+         score_array = []
+
+     try:
+         inter_points = np.array(inter_points)
+         inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+         inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+     except:
+         inter_points = []
+
+     return new_segments, squares, score_array, inter_points
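And similarly for `pred_squares`, which returns the merged segments, candidate squares sorted best-first, their scores, and the corner intersection points (again a sketch with assumed paths; `model` as in the `pred_lines` sketch above):

import cv2
import torch

img = cv2.imread('whiteboard.jpg')  # hypothetical input, HxWx3 uint8
with torch.no_grad():
    segments, squares, scores, corners = pred_squares(img, model)
if len(squares) > 0:
    best = squares[0].astype('int32')  # squares come back sorted by score, best first
    cv2.polylines(img, [best.reshape(-1, 1, 2)], isClosed=True, color=(0, 0, 255), thickness=2)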
src/flux/annotator/tile/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import random
+ import cv2
+ from .guided_filter import FastGuidedFilter
+
+
+ class TileDetector:
+     # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0
+     def __init__(self):
+         pass
+
+     def __call__(self, image):
+         blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0]
+         radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
+         eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0]
+         scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0]
+
+         ksize = int(blur_strength)
+         if ksize % 2 == 0:
+             ksize += 1
+
+         if random.random() > 0.5:
+             image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2)
+         if random.random() > 0.5:
+             filter = FastGuidedFilter(image, radius, eps, scale_factor)
+             image = filter.filter(image)
+         return image
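A short sketch of what a call does (file names are placeholders; because blur strength, radius, eps and scale are re-sampled on every call, repeated calls produce different degradations of the same image, which is what tile-control training data wants):

import cv2

tile = TileDetector()
img = cv2.imread('photo.png')  # hypothetical path, HxWx3 uint8
degraded = tile(img)           # randomly Gaussian-blurred and/or guided-filtered copy
cv2.imwrite('photo_tile.png', degraded)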
src/flux/annotator/tile/guided_filter.py ADDED
@@ -0,0 +1,280 @@
+ # -*- coding: utf-8 -*-
+ ## @package guided_filter.core.filters
+ #
+ # Implementation of guided filter.
+ # * GuidedFilter: Original guided filter.
+ # * FastGuidedFilter: Fast version of the guided filter.
+ # @author tody
+ # @date 2015/08/26
+
+ import numpy as np
+ import cv2
+
+ ## Convert image into float32 type.
+ def to32F(img):
+     if img.dtype == np.float32:
+         return img
+     return (1.0 / 255.0) * np.float32(img)
+
+ ## Convert image into uint8 type.
+ def to8U(img):
+     if img.dtype == np.uint8:
+         return img
+     return np.clip(np.uint8(255.0 * img), 0, 255)
+
+ ## Return if the input image is gray or not.
+ def _isGray(I):
+     return len(I.shape) == 2
+
+
+ ## Return down sampled image.
+ # @param scale (w/s, h/s) image will be created.
+ # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
+ def _downSample(I, scale=4, shape=None):
+     if shape is not None:
+         h, w = shape
+         return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
+
+     h, w = I.shape[:2]
+     return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
+
+
+ ## Return up sampled image.
+ # @param scale (w*s, h*s) image will be created.
+ # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
+ def _upSample(I, scale=2, shape=None):
+     if shape is not None:
+         h, w = shape
+         return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
+
+     h, w = I.shape[:2]
+     return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
+
+ ## Fast guide filter.
+ class FastGuidedFilter:
+     ## Constructor.
+     # @param I Input guidance image. Color or gray.
+     # @param radius Radius of Guided Filter.
+     # @param epsilon Regularization term of Guided Filter.
+     # @param scale Down sampled scale.
+     def __init__(self, I, radius=5, epsilon=0.4, scale=4):
+         I_32F = to32F(I)
+         self._I = I_32F
+         h, w = I.shape[:2]
+
+         I_sub = _downSample(I_32F, scale)
+
+         self._I_sub = I_sub
+         radius = int(radius / scale)
+
+         if _isGray(I):
+             self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
+         else:
+             self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
+
+     ## Apply filter for the input image.
+     # @param p Input image for the filtering.
+     def filter(self, p):
+         p_32F = to32F(p)
+         shape_original = p.shape[:2]
+
+         p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
+
+         if _isGray(p_sub):
+             return self._filterGray(p_sub, shape_original)
+
+         cs = p.shape[2]
+         q = np.array(p_32F)
+
+         for ci in range(cs):
+             q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
+         return to8U(q)
+
+     def _filterGray(self, p_sub, shape_original):
+         ab_sub = self._guided_filter._computeCoefficients(p_sub)
+         ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
+         return self._guided_filter._computeOutput(ab, self._I)
+
+
+ ## Guide filter.
+ class GuidedFilter:
+     ## Constructor.
+     # @param I Input guidance image. Color or gray.
+     # @param radius Radius of Guided Filter.
+     # @param epsilon Regularization term of Guided Filter.
+     def __init__(self, I, radius=5, epsilon=0.4):
+         I_32F = to32F(I)
+
+         if _isGray(I):
+             self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
+         else:
+             self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
+
+     ## Apply filter for the input image.
+     # @param p Input image for the filtering.
+     def filter(self, p):
+         return to8U(self._guided_filter.filter(p))
+
+
+ ## Common parts of guided filter.
+ #
+ # This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
+ # Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
+ # GuidedFilterCommon.filter computes filtered image for color and gray.
+ class GuidedFilterCommon:
+     def __init__(self, guided_filter):
+         self._guided_filter = guided_filter
+
+     ## Apply filter for the input image.
+     # @param p Input image for the filtering.
+     def filter(self, p):
+         p_32F = to32F(p)
+         if _isGray(p_32F):
+             return self._filterGray(p_32F)
+
+         cs = p.shape[2]
+         q = np.array(p_32F)
+
+         for ci in range(cs):
+             q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
+         return q
+
+     def _filterGray(self, p):
+         ab = self._guided_filter._computeCoefficients(p)
+         return self._guided_filter._computeOutput(ab, self._guided_filter._I)
+
+
+ ## Guided filter for gray guidance image.
+ class GuidedFilterGray:
+     # @param I Input gray guidance image.
+     # @param radius Radius of Guided Filter.
+     # @param epsilon Regularization term of Guided Filter.
+     def __init__(self, I, radius=5, epsilon=0.4):
+         self._radius = 2 * radius + 1
+         self._epsilon = epsilon
+         self._I = to32F(I)
+         self._initFilter()
+         self._filter_common = GuidedFilterCommon(self)
+
+     ## Apply filter for the input image.
+     # @param p Input image for the filtering.
+     def filter(self, p):
+         return self._filter_common.filter(p)
+
+     def _initFilter(self):
+         I = self._I
+         r = self._radius
+         self._I_mean = cv2.blur(I, (r, r))
+         I_mean_sq = cv2.blur(I ** 2, (r, r))
+         self._I_var = I_mean_sq - self._I_mean ** 2
+
+     def _computeCoefficients(self, p):
+         r = self._radius
+         p_mean = cv2.blur(p, (r, r))
+         p_cov = p_mean - self._I_mean * p_mean
+         a = p_cov / (self._I_var + self._epsilon)
+         b = p_mean - a * self._I_mean
+         a_mean = cv2.blur(a, (r, r))
+         b_mean = cv2.blur(b, (r, r))
+         return a_mean, b_mean
+
+     def _computeOutput(self, ab, I):
+         a_mean, b_mean = ab
+         return a_mean * I + b_mean
+
+
+ ## Guided filter for color guidance image.
+ class GuidedFilterColor:
+     # @param I Input color guidance image.
+     # @param radius Radius of Guided Filter.
+     # @param epsilon Regularization term of Guided Filter.
+     def __init__(self, I, radius=5, epsilon=0.2):
+         self._radius = 2 * radius + 1
+         self._epsilon = epsilon
+         self._I = to32F(I)
+         self._initFilter()
+         self._filter_common = GuidedFilterCommon(self)
+
+     ## Apply filter for the input image.
+     # @param p Input image for the filtering.
+     def filter(self, p):
+         return self._filter_common.filter(p)
+
+     def _initFilter(self):
+         I = self._I
+         r = self._radius
+         eps = self._epsilon
+
+         Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
+
+         self._Ir_mean = cv2.blur(Ir, (r, r))
+         self._Ig_mean = cv2.blur(Ig, (r, r))
+         self._Ib_mean = cv2.blur(Ib, (r, r))
+
+         Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
+         Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
+         Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
+         Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
+         Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
+         Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
+
+         Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
+         Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
+         Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
+         Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
+         Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
+         Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
+
+         I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
+         Irr_inv /= I_cov
+         Irg_inv /= I_cov
+         Irb_inv /= I_cov
+         Igg_inv /= I_cov
+         Igb_inv /= I_cov
+         Ibb_inv /= I_cov
+
+         self._Irr_inv = Irr_inv
+         self._Irg_inv = Irg_inv
+         self._Irb_inv = Irb_inv
+         self._Igg_inv = Igg_inv
+         self._Igb_inv = Igb_inv
+         self._Ibb_inv = Ibb_inv
+
+     def _computeCoefficients(self, p):
+         r = self._radius
+         I = self._I
+         Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
+
+         p_mean = cv2.blur(p, (r, r))
+
+         Ipr_mean = cv2.blur(Ir * p, (r, r))
+         Ipg_mean = cv2.blur(Ig * p, (r, r))
+         Ipb_mean = cv2.blur(Ib * p, (r, r))
+
+         Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
+         Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
+         Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
+
+         ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
+         ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
+         ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
+         b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
+
+         ar_mean = cv2.blur(ar, (r, r))
+         ag_mean = cv2.blur(ag, (r, r))
+         ab_mean = cv2.blur(ab, (r, r))
+         b_mean = cv2.blur(b, (r, r))
+
+         return ar_mean, ag_mean, ab_mean, b_mean
+
+     def _computeOutput(self, ab, I):
+         ar_mean, ag_mean, ab_mean, b_mean = ab
+
+         Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
+
+         q = (ar_mean * Ir +
+              ag_mean * Ig +
+              ab_mean * Ib +
+              b_mean)
+
+         return q
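A self-filtering sketch (guidance image = input image, which gives edge-preserving smoothing; file names and parameter values are placeholders chosen for illustration):

import cv2

img = cv2.imread('photo.png')  # HxWx3 uint8, hypothetical path
gf = FastGuidedFilter(img, radius=8, epsilon=0.02, scale=4)
smoothed = gf.filter(img)      # uint8 output: smoothed, but strong edges are preserved
cv2.imwrite('photo_guided.png', smoothed)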
src/flux/annotator/util.py ADDED
@@ -0,0 +1,38 @@
+ import numpy as np
+ import cv2
+ import os
+
+
+ annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+ def HWC3(x):
+     assert x.dtype == np.uint8
+     if x.ndim == 2:
+         x = x[:, :, None]
+     assert x.ndim == 3
+     H, W, C = x.shape
+     assert C == 1 or C == 3 or C == 4
+     if C == 3:
+         return x
+     if C == 1:
+         return np.concatenate([x, x, x], axis=2)
+     if C == 4:
+         color = x[:, :, 0:3].astype(np.float32)
+         alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+         y = color * alpha + 255.0 * (1.0 - alpha)
+         y = y.clip(0, 255).astype(np.uint8)
+         return y
+
+
+ def resize_image(input_image, resolution):
+     H, W, C = input_image.shape
+     H = float(H)
+     W = float(W)
+     k = float(resolution) / min(H, W)
+     H *= k
+     W *= k
+     H = int(np.round(H / 64.0)) * 64
+     W = int(np.round(W / 64.0)) * 64
+     img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+     return img
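Two quick sanity checks of the helpers above (a sketch; the shapes are chosen for illustration):

import numpy as np

# resize_image scales the short side to `resolution`, then snaps both sides to multiples of 64:
img = np.zeros((720, 1280, 3), dtype=np.uint8)
print(resize_image(img, 512).shape)  # (512, 896, 3): 1280 * 512/720 = 910.2, which rounds to 14 * 64 = 896

# HWC3 promotes grayscale to 3 channels and flattens RGBA onto a white background:
print(HWC3(np.zeros((64, 64), dtype=np.uint8)).shape)  # (64, 64, 3)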
src/flux/annotator/zoe/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Intelligent Systems Lab Org
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
src/flux/annotator/zoe/__init__.py ADDED
@@ -0,0 +1,48 @@
+ # ZoeDepth
+ # https://github.com/isl-org/ZoeDepth
+
+ import os
+ import cv2
+ import numpy as np
+ import torch
+
+ from einops import rearrange
+ from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth
+ from .zoedepth.utils.config import get_config
+ from ...annotator.util import annotator_ckpts_path
+ from huggingface_hub import hf_hub_download
+
+
+ class ZoeDetector:
+     def __init__(self):
+         model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt")
+         if not os.path.exists(model_path):
+             model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt")
+         conf = get_config("zoedepth", "infer")
+         model = ZoeDepth.build_from_config(conf)
+         model.load_state_dict(torch.load(model_path)['model'], strict=False)
+         model = model.cuda()
+         model.device = 'cuda'
+         model.eval()
+         self.model = model
+
+     def __call__(self, input_image):
+         assert input_image.ndim == 3
+         image_depth = input_image
+         with torch.no_grad():
+             image_depth = torch.from_numpy(image_depth).float().cuda()
+             image_depth = image_depth / 255.0
+             image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+             depth = self.model.infer(image_depth)
+
+             depth = depth[0, 0].cpu().numpy()
+
+             vmin = np.percentile(depth, 2)
+             vmax = np.percentile(depth, 85)
+
+             depth -= vmin
+             depth /= vmax - vmin
+             depth = 1.0 - depth
+             depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
+
+             return depth_image
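A minimal sketch of the detector in use (it assumes a CUDA machine and network access so the ZoeD_M12_N.pt checkpoint can be fetched on first run; the image path is a placeholder):

import cv2

zoe = ZoeDetector()             # builds ZoeDepth and loads ZoeD_M12_N.pt
img = cv2.imread('room.jpg')    # HxWx3 uint8
depth_map = zoe(img)            # HxW uint8; near objects are bright after the 1.0 - depth flip
cv2.imwrite('room_depth.png', depth_map)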
src/flux/annotator/zoe/zoedepth/data/__init__.py ADDED
@@ -0,0 +1,24 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
@@ -0,0 +1,573 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Intelligent Systems Lab Org
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # File author: Shariq Farooq Bhat
24
+
25
+ # This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
26
+
27
+ import itertools
28
+ import os
29
+ import random
30
+
31
+ import numpy as np
32
+ import cv2
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.utils.data.distributed
36
+ from zoedepth.utils.easydict import EasyDict as edict
37
+ from PIL import Image, ImageOps
38
+ from torch.utils.data import DataLoader, Dataset
39
+ from torchvision import transforms
40
+
41
+ from zoedepth.utils.config import change_dataset
42
+
43
+ from .ddad import get_ddad_loader
44
+ from .diml_indoor_test import get_diml_indoor_loader
45
+ from .diml_outdoor_test import get_diml_outdoor_loader
46
+ from .diode import get_diode_loader
47
+ from .hypersim import get_hypersim_loader
48
+ from .ibims import get_ibims_loader
49
+ from .sun_rgbd_loader import get_sunrgbd_loader
50
+ from .vkitti import get_vkitti_loader
51
+ from .vkitti2 import get_vkitti2_loader
52
+
53
+ from .preprocess import CropParams, get_white_border, get_black_border
54
+
55
+
56
+ def _is_pil_image(img):
57
+ return isinstance(img, Image.Image)
58
+
59
+
60
+ def _is_numpy_image(img):
61
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
62
+
63
+
64
+ def preprocessing_transforms(mode, **kwargs):
65
+ return transforms.Compose([
66
+ ToTensor(mode=mode, **kwargs)
67
+ ])
68
+
69
+
70
+ class DepthDataLoader(object):
71
+ def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
72
+ """
73
+ Data loader for depth datasets
74
+
75
+ Args:
76
+ config (dict): Config dictionary. Refer to utils/config.py
77
+ mode (str): "train" or "online_eval"
78
+ device (str, optional): Device to load the data on. Defaults to 'cpu'.
79
+ transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
80
+ """
81
+
82
+ self.config = config
83
+
84
+ if config.dataset == 'ibims':
85
+ self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
86
+ return
87
+
88
+ if config.dataset == 'sunrgbd':
89
+ self.data = get_sunrgbd_loader(
90
+ data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
91
+ return
92
+
93
+ if config.dataset == 'diml_indoor':
94
+ self.data = get_diml_indoor_loader(
95
+ data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
96
+ return
97
+
98
+ if config.dataset == 'diml_outdoor':
99
+ self.data = get_diml_outdoor_loader(
100
+ data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
101
+ return
102
+
103
+ if "diode" in config.dataset:
104
+ self.data = get_diode_loader(
105
+ config[config.dataset+"_root"], batch_size=1, num_workers=1)
106
+ return
107
+
108
+ if config.dataset == 'hypersim_test':
109
+ self.data = get_hypersim_loader(
110
+ config.hypersim_test_root, batch_size=1, num_workers=1)
111
+ return
112
+
113
+ if config.dataset == 'vkitti':
114
+ self.data = get_vkitti_loader(
115
+ config.vkitti_root, batch_size=1, num_workers=1)
116
+ return
117
+
118
+ if config.dataset == 'vkitti2':
119
+ self.data = get_vkitti2_loader(
120
+ config.vkitti2_root, batch_size=1, num_workers=1)
121
+ return
122
+
123
+ if config.dataset == 'ddad':
124
+ self.data = get_ddad_loader(config.ddad_root, resize_shape=(
125
+ 352, 1216), batch_size=1, num_workers=1)
126
+ return
127
+
128
+ img_size = self.config.get("img_size", None)
129
+ img_size = img_size if self.config.get(
130
+ "do_input_resize", False) else None
131
+
132
+ if transform is None:
133
+ transform = preprocessing_transforms(mode, size=img_size)
134
+
135
+ if mode == 'train':
136
+
137
+ Dataset = DataLoadPreprocess
138
+ self.training_samples = Dataset(
139
+ config, mode, transform=transform, device=device)
140
+
141
+ if config.distributed:
142
+ self.train_sampler = torch.utils.data.distributed.DistributedSampler(
143
+ self.training_samples)
144
+ else:
145
+ self.train_sampler = None
146
+
147
+ self.data = DataLoader(self.training_samples,
148
+ batch_size=config.batch_size,
149
+ shuffle=(self.train_sampler is None),
150
+ num_workers=config.workers,
151
+ pin_memory=True,
152
+ persistent_workers=True,
153
+ # prefetch_factor=2,
154
+ sampler=self.train_sampler)
155
+
156
+ elif mode == 'online_eval':
157
+ self.testing_samples = DataLoadPreprocess(
158
+ config, mode, transform=transform)
159
+ if config.distributed: # redundant. here only for readability and to be more explicit
160
+ # Give whole test set to all processes (and report evaluation only on one) regardless
161
+ self.eval_sampler = None
162
+ else:
163
+ self.eval_sampler = None
164
+ self.data = DataLoader(self.testing_samples, 1,
165
+ shuffle=kwargs.get("shuffle_test", False),
166
+ num_workers=1,
167
+ pin_memory=False,
168
+ sampler=self.eval_sampler)
169
+
170
+ elif mode == 'test':
171
+ self.testing_samples = DataLoadPreprocess(
172
+ config, mode, transform=transform)
173
+ self.data = DataLoader(self.testing_samples,
174
+ 1, shuffle=False, num_workers=1)
175
+
176
+ else:
177
+ print(
178
+ 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
179
+
180
+
181
+ def repetitive_roundrobin(*iterables):
182
+ """
183
+ cycles through iterables but sample wise
184
+ first yield first sample from first iterable then first sample from second iterable and so on
185
+ then second sample from first iterable then second sample from second iterable and so on
186
+
187
+ If one iterable is shorter than the others, it is repeated until all iterables are exhausted
188
+ repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
189
+ """
190
+ # Repetitive roundrobin
191
+ iterables_ = [iter(it) for it in iterables]
192
+ exhausted = [False] * len(iterables)
193
+ while not all(exhausted):
194
+ for i, it in enumerate(iterables_):
195
+ try:
196
+ yield next(it)
197
+ except StopIteration:
198
+ exhausted[i] = True
199
+ iterables_[i] = itertools.cycle(iterables[i])
200
+ # First elements may get repeated if one iterable is shorter than the others
201
+ yield next(iterables_[i])
202
+
203
+
204
+ class RepetitiveRoundRobinDataLoader(object):
205
+ def __init__(self, *dataloaders):
206
+ self.dataloaders = dataloaders
207
+
208
+ def __iter__(self):
209
+ return repetitive_roundrobin(*self.dataloaders)
210
+
211
+ def __len__(self):
212
+ # First samples get repeated, thats why the plus one
213
+ return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
214
+
215
+
216
+ class MixedNYUKITTI(object):
217
+ def __init__(self, config, mode, device='cpu', **kwargs):
218
+ config = edict(config)
219
+ config.workers = config.workers // 2
220
+ self.config = config
221
+ nyu_conf = change_dataset(edict(config), 'nyu')
222
+ kitti_conf = change_dataset(edict(config), 'kitti')
223
+
224
+ # make nyu default for testing
225
+ self.config = config = nyu_conf
226
+ img_size = self.config.get("img_size", None)
227
+ img_size = img_size if self.config.get(
228
+ "do_input_resize", False) else None
229
+ if mode == 'train':
230
+ nyu_loader = DepthDataLoader(
231
+ nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
232
+ kitti_loader = DepthDataLoader(
233
+ kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
234
+ # It has been changed to repetitive roundrobin
235
+ self.data = RepetitiveRoundRobinDataLoader(
236
+ nyu_loader, kitti_loader)
237
+ else:
238
+ self.data = DepthDataLoader(nyu_conf, mode, device=device).data
239
+
240
+
241
+ def remove_leading_slash(s):
242
+ if s[0] == '/' or s[0] == '\\':
243
+ return s[1:]
244
+ return s
245
+
246
+
247
+ class CachedReader:
248
+ def __init__(self, shared_dict=None):
249
+ if shared_dict:
250
+ self._cache = shared_dict
251
+ else:
252
+ self._cache = {}
253
+
254
+ def open(self, fpath):
255
+ im = self._cache.get(fpath, None)
256
+ if im is None:
257
+ im = self._cache[fpath] = Image.open(fpath)
258
+ return im
259
+
260
+
261
+ class ImReader:
262
+ def __init__(self):
263
+ pass
264
+
265
+ # @cache
266
+ def open(self, fpath):
267
+ return Image.open(fpath)
268
+
269
+
270
+ class DataLoadPreprocess(Dataset):
271
+ def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
272
+ self.config = config
273
+ if mode == 'online_eval':
274
+ with open(config.filenames_file_eval, 'r') as f:
275
+ self.filenames = f.readlines()
276
+ else:
277
+ with open(config.filenames_file, 'r') as f:
278
+ self.filenames = f.readlines()
279
+
280
+ self.mode = mode
281
+ self.transform = transform
282
+ self.to_tensor = ToTensor(mode)
283
+ self.is_for_online_eval = is_for_online_eval
284
+ if config.use_shared_dict:
285
+ self.reader = CachedReader(config.shared_dict)
286
+ else:
287
+ self.reader = ImReader()
288
+
289
+ def postprocess(self, sample):
290
+ return sample
291
+
292
+ def __getitem__(self, idx):
293
+ sample_path = self.filenames[idx]
294
+ focal = float(sample_path.split()[2])
295
+ sample = {}
296
+
297
+ if self.mode == 'train':
298
+ if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
299
+ image_path = os.path.join(
300
+ self.config.data_path, remove_leading_slash(sample_path.split()[3]))
301
+ depth_path = os.path.join(
302
+ self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
303
+ else:
304
+ image_path = os.path.join(
305
+ self.config.data_path, remove_leading_slash(sample_path.split()[0]))
306
+ depth_path = os.path.join(
307
+ self.config.gt_path, remove_leading_slash(sample_path.split()[1]))
308
+
309
+ image = self.reader.open(image_path)
310
+ depth_gt = self.reader.open(depth_path)
311
+ w, h = image.size
312
+
313
+ if self.config.do_kb_crop:
314
+ height = image.height
315
+ width = image.width
316
+ top_margin = int(height - 352)
317
+ left_margin = int((width - 1216) / 2)
318
+ depth_gt = depth_gt.crop(
319
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
320
+ image = image.crop(
321
+ (left_margin, top_margin, left_margin + 1216, top_margin + 352))
322
+
323
+ # Avoid blank boundaries due to pixel registration?
324
+ # Train images have white border. Test images have black border.
325
+ if self.config.dataset == 'nyu' and self.config.avoid_boundary:
326
+ # print("Avoiding Blank Boundaries!")
327
+ # We just crop and pad again with reflect padding to original size
328
+ # original_size = image.size
329
+ crop_params = get_white_border(np.array(image, dtype=np.uint8))
330
+ image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
331
+ depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
332
+
333
+ # Use reflect padding to fill the blank
334
+ image = np.array(image)
335
+ image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
336
+ image = Image.fromarray(image)
337
+
338
+ depth_gt = np.array(depth_gt)
339
+ depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
340
+ depth_gt = Image.fromarray(depth_gt)
341
+
342
+
343
+ if self.config.do_random_rotate and (self.config.aug):
344
+ random_angle = (random.random() - 0.5) * 2 * self.config.degree
345
+ image = self.rotate_image(image, random_angle)
346
+ depth_gt = self.rotate_image(
347
+ depth_gt, random_angle, flag=Image.NEAREST)
348
+
349
+ image = np.asarray(image, dtype=np.float32) / 255.0
350
+ depth_gt = np.asarray(depth_gt, dtype=np.float32)
351
+ depth_gt = np.expand_dims(depth_gt, axis=2)
352
+
353
+ if self.config.dataset == 'nyu':
354
+ depth_gt = depth_gt / 1000.0
355
+ else:
356
+ depth_gt = depth_gt / 256.0
357
+
358
+ if self.config.aug and (self.config.random_crop):
359
+ image, depth_gt = self.random_crop(
360
+ image, depth_gt, self.config.input_height, self.config.input_width)
361
+
362
+ if self.config.aug and self.config.random_translate:
363
+ # print("Random Translation!")
364
+ image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)
365
+
366
+ image, depth_gt = self.train_preprocess(image, depth_gt)
367
+ mask = np.logical_and(depth_gt > self.config.min_depth,
368
+ depth_gt < self.config.max_depth).squeeze()[None, ...]
369
+ sample = {'image': image, 'depth': depth_gt, 'focal': focal,
370
+ 'mask': mask, **sample}
371
+
372
+ else:
373
+ if self.mode == 'online_eval':
374
+ data_path = self.config.data_path_eval
375
+ else:
376
+ data_path = self.config.data_path
377
+
378
+ image_path = os.path.join(
379
+                 data_path, remove_leading_slash(sample_path.split()[0]))
+             image = np.asarray(self.reader.open(image_path),
+                                dtype=np.float32) / 255.0
+
+             if self.mode == 'online_eval':
+                 gt_path = self.config.gt_path_eval
+                 depth_path = os.path.join(
+                     gt_path, remove_leading_slash(sample_path.split()[1]))
+                 has_valid_depth = False
+                 try:
+                     depth_gt = self.reader.open(depth_path)
+                     has_valid_depth = True
+                 except IOError:
+                     depth_gt = False
+                     # print('Missing gt for {}'.format(image_path))
+
+                 if has_valid_depth:
+                     depth_gt = np.asarray(depth_gt, dtype=np.float32)
+                     depth_gt = np.expand_dims(depth_gt, axis=2)
+                     if self.config.dataset == 'nyu':
+                         depth_gt = depth_gt / 1000.0
+                     else:
+                         depth_gt = depth_gt / 256.0
+
+                     mask = np.logical_and(
+                         depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
+                 else:
+                     mask = False
+
+             if self.config.do_kb_crop:
+                 height = image.shape[0]
+                 width = image.shape[1]
+                 top_margin = int(height - 352)
+                 left_margin = int((width - 1216) / 2)
+                 image = image[top_margin:top_margin + 352,
+                               left_margin:left_margin + 1216, :]
+                 if self.mode == 'online_eval' and has_valid_depth:
+                     depth_gt = depth_gt[top_margin:top_margin +
+                                         352, left_margin:left_margin + 1216, :]
+
+             if self.mode == 'online_eval':
+                 sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
+                           'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
+                           'mask': mask}
+             else:
+                 sample = {'image': image, 'focal': focal}
+
+         if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
+             mask = np.logical_and(depth_gt > self.config.min_depth,
+                                   depth_gt < self.config.max_depth).squeeze()[None, ...]
+             sample['mask'] = mask
+
+         if self.transform:
+             sample = self.transform(sample)
+
+         sample = self.postprocess(sample)
+         sample['dataset'] = self.config.dataset
+         sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}
+
+         return sample
+
+     def rotate_image(self, image, angle, flag=Image.BILINEAR):
+         result = image.rotate(angle, resample=flag)
+         return result
+
+     def random_crop(self, img, depth, height, width):
+         assert img.shape[0] >= height
+         assert img.shape[1] >= width
+         assert img.shape[0] == depth.shape[0]
+         assert img.shape[1] == depth.shape[1]
+         x = random.randint(0, img.shape[1] - width)
+         y = random.randint(0, img.shape[0] - height)
+         img = img[y:y + height, x:x + width, :]
+         depth = depth[y:y + height, x:x + width, :]
+
+         return img, depth
+
+     def random_translate(self, img, depth, max_t=20):
+         assert img.shape[0] == depth.shape[0]
+         assert img.shape[1] == depth.shape[1]
+         p = self.config.translate_prob
+         do_translate = random.random()
+         if do_translate > p:
+             return img, depth
+         x = random.randint(-max_t, max_t)
+         y = random.randint(-max_t, max_t)
+         M = np.float32([[1, 0, x], [0, 1, y]])
+         # print(img.shape, depth.shape)
+         img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
+         depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
+         depth = depth.squeeze()[..., None]  # add channel dim back. Affine warp removes it
+         # print("after", img.shape, depth.shape)
+         return img, depth
+
+     def train_preprocess(self, image, depth_gt):
+         if self.config.aug:
+             # Random flipping
+             do_flip = random.random()
+             if do_flip > 0.5:
+                 image = (image[:, ::-1, :]).copy()
+                 depth_gt = (depth_gt[:, ::-1, :]).copy()
+
+             # Random gamma, brightness, color augmentation
+             do_augment = random.random()
+             if do_augment > 0.5:
+                 image = self.augment_image(image)
+
+         return image, depth_gt
+
+     def augment_image(self, image):
+         # gamma augmentation
+         gamma = random.uniform(0.9, 1.1)
+         image_aug = image ** gamma
+
+         # brightness augmentation
+         if self.config.dataset == 'nyu':
+             brightness = random.uniform(0.75, 1.25)
+         else:
+             brightness = random.uniform(0.9, 1.1)
+         image_aug = image_aug * brightness
+
+         # color augmentation
+         colors = np.random.uniform(0.9, 1.1, size=3)
+         white = np.ones((image.shape[0], image.shape[1]))
+         color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
+         image_aug *= color_image
+         image_aug = np.clip(image_aug, 0, 1)
+
+         return image_aug
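The photometric augmentation above composes three independent perturbations on an image already scaled to [0, 1]: a gamma in [0.9, 1.1], a brightness factor ([0.75, 1.25] for NYU, [0.9, 1.1] otherwise), and a per-channel color scale in [0.9, 1.1]. A self-contained sketch of the same recipe (the seed and toy input are illustrative, not part of this file):

import numpy as np
import random

random.seed(0)  # illustrative only
image = np.random.rand(4, 4, 3).astype(np.float32)  # toy image in [0, 1]

gamma = random.uniform(0.9, 1.1)
brightness = random.uniform(0.75, 1.25)       # NYU-style range
colors = np.random.uniform(0.9, 1.1, size=3)  # one scale per RGB channel
aug = np.clip((image ** gamma) * brightness * colors, 0, 1)  # broadcasts over the channel axis
assert aug.shape == image.shape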
+
+     def __len__(self):
+         return len(self.filenames)
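Earlier in `__getitem__`, the `do_kb_crop` branch applies the standard KITTI benchmark crop: keep the bottom 352 rows and a horizontally centered 1216-column window so all frames share one resolution. A minimal sketch of that arithmetic on an assumed 375x1242 KITTI-sized frame (the input size is illustrative):

import numpy as np

image = np.zeros((375, 1242, 3), dtype=np.float32)  # hypothetical KITTI-sized frame
height, width = image.shape[0], image.shape[1]
top_margin = int(height - 352)         # 23 rows dropped from the top
left_margin = int((width - 1216) / 2)  # 13 columns trimmed from each side
cropped = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
assert cropped.shape == (352, 1216, 3)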
+
+
+ class ToTensor(object):
+     def __init__(self, mode, do_normalize=False, size=None):
+         self.mode = mode
+         self.normalize = transforms.Normalize(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
+         self.size = size
+         if size is not None:
+             self.resize = transforms.Resize(size=size)
+         else:
+             self.resize = nn.Identity()
+
+     def __call__(self, sample):
+         image, focal = sample['image'], sample['focal']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         image = self.resize(image)
+
+         if self.mode == 'test':
+             return {'image': image, 'focal': focal}
+
+         depth = sample['depth']
+         if self.mode == 'train':
+             depth = self.to_tensor(depth)
+             return {**sample, 'image': image, 'depth': depth, 'focal': focal}
+         else:
+             has_valid_depth = sample['has_valid_depth']
+             image = self.resize(image)
+             return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
+                     'image_path': sample['image_path'], 'depth_path': sample['depth_path']}
+
+     def to_tensor(self, pic):
+         if not (_is_pil_image(pic) or _is_numpy_image(pic)):
+             raise TypeError(
+                 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
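`to_tensor` above turns a PIL image into a CHW tensor by viewing the raw byte buffer as H x W x C and permuting axes; `transpose(0, 1).transpose(0, 2)` is just a two-step HWC-to-CHW permutation. A tiny sketch verifying that on a hypothetical 6x4 RGB image:

import numpy as np
import torch
from PIL import Image

pic = Image.fromarray(np.random.randint(0, 255, (4, 6, 3), dtype=np.uint8))  # hypothetical input
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[1], pic.size[0], 3)             # (H, W, C) = (4, 6, 3)
chw = img.transpose(0, 1).transpose(0, 2).contiguous()  # (C, H, W) = (3, 4, 6)
assert chw.shape == torch.Size([3, 4, 6])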
src/flux/annotator/zoe/zoedepth/data/ddad.py ADDED
@@ -0,0 +1,117 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import os
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+
+
+ class ToTensor(object):
+     def __init__(self, resize_shape):
+         # self.normalize = transforms.Normalize(
+         #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         self.normalize = lambda x: x
+         self.resize = transforms.Resize(resize_shape)
+
+     def __call__(self, sample):
+         image, depth = sample['image'], sample['depth']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         depth = self.to_tensor(depth)
+
+         image = self.resize(image)
+
+         return {'image': image, 'depth': depth, 'dataset': "ddad"}
+
+     def to_tensor(self, pic):
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
+
+
+ class DDAD(Dataset):
+     def __init__(self, data_dir_root, resize_shape):
+         import glob
+
+         # image paths are of the form <data_dir_root>/*_rgb.png, with depth beside them as *_depth.npy
+         self.image_files = glob.glob(os.path.join(data_dir_root, '*.png'))
+         self.depth_files = [r.replace("_rgb.png", "_depth.npy")
+                             for r in self.image_files]
+         self.transform = ToTensor(resize_shape)
+
+     def __getitem__(self, idx):
+
+         image_path = self.image_files[idx]
+         depth_path = self.depth_files[idx]
+
+         image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+         depth = np.load(depth_path)  # meters
+
+         # depth[depth > 8] = -1
+         depth = depth[..., None]
+
+         sample = dict(image=image, depth=depth)
+         sample = self.transform(sample)
+
+         if idx == 0:
+             print(sample["image"].shape)
+
+         return sample
+
+     def __len__(self):
+         return len(self.image_files)
+
+
+ def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs):
+     dataset = DDAD(data_dir_root, resize_shape)
+     return DataLoader(dataset, batch_size, **kwargs)
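A hedged usage sketch for `get_ddad_loader` — the directory path and `resize_shape` below are illustrative assumptions, not values this file prescribes:

# Assumed layout: a flat directory of *_rgb.png frames with *_depth.npy files beside them.
loader = get_ddad_loader("datasets/ddad", resize_shape=(352, 1216), batch_size=1)
batch = next(iter(loader))
image = batch["image"]  # float tensor, (1, 3, 352, 1216) after the resize
depth = batch["depth"]  # metric depth in meters, kept at its original resolution

Note that only the image is resized in `ToTensor`, so image and depth resolutions differ unless `resize_shape` matches the source frames.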
src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py ADDED
@@ -0,0 +1,125 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import os
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+
+
+ class ToTensor(object):
+     def __init__(self):
+         # self.normalize = transforms.Normalize(
+         #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         self.normalize = lambda x: x
+         self.resize = transforms.Resize((480, 640))
+
+     def __call__(self, sample):
+         image, depth = sample['image'], sample['depth']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         depth = self.to_tensor(depth)
+
+         image = self.resize(image)
+
+         return {'image': image, 'depth': depth, 'dataset': "diml_indoor"}
+
+     def to_tensor(self, pic):
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
+
+
+ class DIML_Indoor(Dataset):
+     def __init__(self, data_dir_root):
+         import glob
+
+         # image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
+         self.image_files = glob.glob(os.path.join(
+             data_dir_root, "LR", '*', 'color', '*.png'))
+         self.depth_files = [r.replace("color", "depth_filled").replace(
+             "_c.png", "_depth_filled.png") for r in self.image_files]
+         self.transform = ToTensor()
+
+     def __getitem__(self, idx):
+         image_path = self.image_files[idx]
+         depth_path = self.depth_files[idx]
+
+         image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+         depth = np.asarray(Image.open(depth_path),
+                            dtype='uint16') / 1000.0  # mm to meters
+
+         # print(np.shape(image))
+         # print(np.shape(depth))
+
+         # depth[depth > 8] = -1
+         depth = depth[..., None]
+
+         sample = dict(image=image, depth=depth)
+
+         # return sample
+         sample = self.transform(sample)
+
+         if idx == 0:
+             print(sample["image"].shape)
+
+         return sample
+
+     def __len__(self):
+         return len(self.image_files)
+
+
+ def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs):
+     dataset = DIML_Indoor(data_dir_root)
+     return DataLoader(dataset, batch_size, **kwargs)
+
+ # get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR")
+ # get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR")
src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py ADDED
@@ -0,0 +1,114 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import os
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+
+
+ class ToTensor(object):
+     def __init__(self):
+         # self.normalize = transforms.Normalize(
+         #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         self.normalize = lambda x: x
+
+     def __call__(self, sample):
+         image, depth = sample['image'], sample['depth']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         depth = self.to_tensor(depth)
+
+         return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"}
+
+     def to_tensor(self, pic):
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
+
+
+ class DIML_Outdoor(Dataset):
+     def __init__(self, data_dir_root):
+         import glob
+
+         # image paths are of the form <data_dir_root>/<scene>/{outleft, depthmap}/*.png
+         self.image_files = glob.glob(os.path.join(
+             data_dir_root, "*", 'outleft', '*.png'))
+         self.depth_files = [r.replace("outleft", "depthmap")
+                             for r in self.image_files]
+         self.transform = ToTensor()
+
+     def __getitem__(self, idx):
+         image_path = self.image_files[idx]
+         depth_path = self.depth_files[idx]
+
+         image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+         depth = np.asarray(Image.open(depth_path),
+                            dtype='uint16') / 1000.0  # mm to meters
+
+         # depth[depth > 8] = -1
+         depth = depth[..., None]
+
+         sample = dict(image=image, depth=depth, dataset="diml_outdoor")
+
+         # return sample
+         return self.transform(sample)
+
+     def __len__(self):
+         return len(self.image_files)
+
+
+ def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs):
+     dataset = DIML_Outdoor(data_dir_root)
+     return DataLoader(dataset, batch_size, **kwargs)
+
+ # get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR")
+ # get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR")
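Both DIML loaders decode ground truth the same way: the PNG stores depth in millimeters as uint16, and dividing by 1000 yields float meters (uint16 caps the representable depth at 65.535 m). A minimal sketch with a hypothetical file name:

import numpy as np
from PIL import Image

depth_mm = np.asarray(Image.open("000001_depth.png"), dtype='uint16')  # hypothetical file
depth_m = depth_mm / 1000.0  # integer millimeters -> float64 meters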
src/flux/annotator/zoe/zoedepth/data/diode.py ADDED
@@ -0,0 +1,125 @@
+ # MIT License
+
+ # Copyright (c) 2022 Intelligent Systems Lab Org
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # File author: Shariq Farooq Bhat
+
+ import os
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+
+
+ class ToTensor(object):
+     def __init__(self):
+         # self.normalize = transforms.Normalize(
+         #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         self.normalize = lambda x: x
+         self.resize = transforms.Resize(480)
+
+     def __call__(self, sample):
+         image, depth = sample['image'], sample['depth']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+         depth = self.to_tensor(depth)
+
+         image = self.resize(image)
+
+         return {'image': image, 'depth': depth, 'dataset': "diode"}
+
+     def to_tensor(self, pic):
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(
+                 torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
+
+
+ class DIODE(Dataset):
+     def __init__(self, data_dir_root):
+         import glob
+
+         # image paths are of the form <data_dir_root>/scene_#/scan_#/*.png
+         self.image_files = glob.glob(
+             os.path.join(data_dir_root, '*', '*', '*.png'))
+         self.depth_files = [r.replace(".png", "_depth.npy")
+                             for r in self.image_files]
+         self.depth_mask_files = [
+             r.replace(".png", "_depth_mask.npy") for r in self.image_files]
+         self.transform = ToTensor()
+
+     def __getitem__(self, idx):
+         image_path = self.image_files[idx]
+         depth_path = self.depth_files[idx]
+         depth_mask_path = self.depth_mask_files[idx]
+
+         image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+         depth = np.load(depth_path)  # in meters
+         valid = np.load(depth_mask_path)  # binary
+
+         # depth[depth > 8] = -1
+         # depth = depth[..., None]
+
+         sample = dict(image=image, depth=depth, valid=valid)  # note: ToTensor above drops 'valid'
+
+         # return sample
+         sample = self.transform(sample)
+
+         if idx == 0:
+             print(sample["image"].shape)
+
+         return sample
+
+     def __len__(self):
+         return len(self.image_files)
+
+
+ def get_diode_loader(data_dir_root, batch_size=1, **kwargs):
+     dataset = DIODE(data_dir_root)
+     return DataLoader(dataset, batch_size, **kwargs)
+
+ # get_diode_loader(data_dir_root="datasets/diode/val/outdoor")
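`DIODE.__getitem__` loads the binary validity mask, but the `ToTensor` above only forwards `image`, `depth`, and `dataset`, so `valid` never reaches the batch. A sketch of how such a mask is typically applied, with hypothetical paths:

import numpy as np

depth = np.load("scene_00001/scan_00001/00001_depth.npy")       # hypothetical path, meters
valid = np.load("scene_00001/scan_00001/00001_depth_mask.npy")  # hypothetical path, 1 = trusted
mean_depth = depth[valid.astype(bool)].mean()  # statistics over valid pixels only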