diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..168e810c7d31e26f8abfba9c7f9ed4a6b616098d Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..7b2be8e2d15c2b740d63f173b0f9686be99f0309 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/**/* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fc484d7f438a95ed9ffd32668b973285661063eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +assets/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST diff --git a/README.md b/README.md index 90723adfa62d981d7be121989cd82394bc4d4fc6..0017208d8da0d0f05c74596b97264df37ae8dfec 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ --- -title: Teset Demo -emoji: 🖼 -colorFrom: purple -colorTo: red +title: ByteMorph Demo +emoji: 📊 +colorFrom: blue +colorTo: purple sdk: gradio -sdk_version: 5.25.2 +sdk_version: 5.31.0 app_file: app.py pinned: false +license: other +short_description: Online Demo for ByteMorph --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py index 652dc459175b17dc4acd2394b8b8ad1edd2394a5..fb38fc428111146fa7be64bd71f8fdf3eb20040d 100644 --- a/app.py +++ b/app.py @@ -1,154 +1,173 @@ import gradio as gr -import numpy as np -import random - -# import spaces #[uncomment to use ZeroGPU] -from diffusers import DiffusionPipeline import torch - -device = "cuda" if torch.cuda.is_available() else "cpu" -model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use - -if torch.cuda.is_available(): - torch_dtype = torch.float16 -else: - torch_dtype = torch.float32 - -pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype) -pipe = pipe.to(device) - -MAX_SEED = np.iinfo(np.int32).max -MAX_IMAGE_SIZE = 1024 - - -# @spaces.GPU #[uncomment to use ZeroGPU] -def infer( - prompt, - negative_prompt, - seed, - randomize_seed, - width, - height, - guidance_scale, - num_inference_steps, - progress=gr.Progress(track_tqdm=True), -): - if randomize_seed: - seed = random.randint(0, MAX_SEED) - - generator = torch.Generator().manual_seed(seed) - - image = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - width=width, - height=height, - generator=generator, - ).images[0] - - return image, seed - - -examples = [ - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", - "An astronaut riding a green horse", - "A delicious ceviche cheesecake slice", -] - -css = """ -#col-container { - margin: 0 auto; - max-width: 640px; -} +import spaces +import os +import numpy as np +from PIL import Image + +from huggingface_hub import hf_hub_download +from safetensors.torch import 
load_file +from omegaconf import OmegaConf +from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image +from src.flux.xflux_pipeline import XFluxSampler +from image_datasets.dataset import image_resize + +# ===== No CUDA/model initialization globally ===== +args = OmegaConf.load("inference_configs/inference.yaml") +is_schnell = args.model_name == "flux-schnell" + +# sampler = None +device = torch.device("cuda") +dtype = torch.bfloat16 +dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype) +vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype) +t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype) +clip = load_clip("cpu").to(device, dtype=dtype) + +vae.requires_grad_(False) +t5.requires_grad_(False) +clip.requires_grad_(False) + +model_path = hf_hub_download( + repo_id="Boese0601/ByteMorpher", + filename="dit.safetensors", + use_auth_token=os.getenv("HF_TOKEN") +) +state_dict = load_file(model_path) +dit.load_state_dict(state_dict) +dit.eval() +dit.to(device, dtype=dtype) + +sampler = XFluxSampler( + clip=clip, + t5=t5, + ae=vae, + model=dit, + device=device, + ip_loaded=False, + spatial_condition=False, + clip_image_processor=None, + image_encoder=None, + improj=None +) +#test push +@spaces.GPU +def generate(image: Image.Image, edit_prompt: str): + # global sampler + # device = torch.device("cuda") + # dtype = torch.bfloat16 + + # if sampler is None: + # dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype) + # vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype) + # t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype) + # clip = load_clip("cpu").to(device, dtype=dtype) + + # vae.requires_grad_(False) + # t5.requires_grad_(False) + # clip.requires_grad_(False) + + # model_path = hf_hub_download( + # repo_id="Boese0601/ByteMorpher", + # filename="dit.safetensors", + # use_auth_token=os.getenv("HF_TOKEN") + # ) + # state_dict = load_file(model_path) + # dit.load_state_dict(state_dict) + # dit.eval() + + # sampler = XFluxSampler( + # clip=clip, + # t5=t5, + # ae=vae, + # model=dit, + # device=device, + # ip_loaded=False, + # spatial_condition=False, + # clip_image_processor=None, + # image_encoder=None, + # improj=None + # ) + + img = image_resize(image, 512) + w, h = img.size + img = img.resize(((w // 32) * 32, (h // 32) * 32)) + img = torch.from_numpy((np.array(img) / 127.5) - 1) + img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype) + + with torch.no_grad(): + result = sampler( + prompt=edit_prompt, + width=args.sample_width, + height=args.sample_height, + num_steps=args.sample_steps, + image_prompt=None, + true_gs=args.cfg_scale, + seed=args.seed, + ip_scale=args.ip_scale if args.use_ip else 1.0, + source_image=img if args.use_spatial_condition else None, + ) + return tensor_to_pil_image(result) + +def get_samples(): + sample_list = [ + { + "image": "assets/0_camera_zoom/20486354.png", + "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.", + }, + ] + return [ + [ + Image.open(sample["image"]).resize((512, 512)), + sample["edit_prompt"], + ] + for sample in sample_list + ] + +header = """ +# ByteMorph + +
+arXiv +HuggingFace +GitHub +
""" -with gr.Blocks(css=css) as demo: - with gr.Column(elem_id="col-container"): - gr.Markdown(" # Text-to-Image Gradio Template") - - with gr.Row(): - prompt = gr.Text( - label="Prompt", - show_label=False, - max_lines=1, - placeholder="Enter your prompt", - container=False, - ) - - run_button = gr.Button("Run", scale=0, variant="primary") - - result = gr.Image(label="Result", show_label=False) - - with gr.Accordion("Advanced Settings", open=False): - negative_prompt = gr.Text( - label="Negative prompt", - max_lines=1, - placeholder="Enter a negative prompt", - visible=False, - ) - - seed = gr.Slider( - label="Seed", - minimum=0, - maximum=MAX_SEED, - step=1, - value=0, - ) - - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) - - with gr.Row(): - width = gr.Slider( - label="Width", - minimum=256, - maximum=MAX_IMAGE_SIZE, - step=32, - value=1024, # Replace with defaults that work for your model - ) - - height = gr.Slider( - label="Height", - minimum=256, - maximum=MAX_IMAGE_SIZE, - step=32, - value=1024, # Replace with defaults that work for your model +def create_app(): + with gr.Blocks() as app: + gr.Markdown(header, elem_id="header") + with gr.Row(equal_height=False): + with gr.Column(variant="panel", elem_classes="inputPanel"): + original_image = gr.Image( + type="pil", label="Condition Image", width=300, elem_id="input" ) + edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt") + submit_btn = gr.Button("Run", elem_id="submit_btn") - with gr.Row(): - guidance_scale = gr.Slider( - label="Guidance scale", - minimum=0.0, - maximum=10.0, - step=0.1, - value=0.0, # Replace with defaults that work for your model - ) + with gr.Column(variant="panel", elem_classes="outputPanel"): + output_image = gr.Image(type="pil", elem_id="output") - num_inference_steps = gr.Slider( - label="Number of inference steps", - minimum=1, - maximum=50, - step=1, - value=2, # Replace with defaults that work for your model - ) + with gr.Row(): + examples = gr.Examples( + examples=get_samples(), + inputs=[original_image, edit_prompt], + label="Examples", + ) - gr.Examples(examples=examples, inputs=[prompt]) - gr.on( - triggers=[run_button.click, prompt.submit], - fn=infer, - inputs=[ - prompt, - negative_prompt, - seed, - randomize_seed, - width, - height, - guidance_scale, - num_inference_steps, - ], - outputs=[result, seed], - ) + submit_btn.click( + fn=generate, + inputs=[original_image, edit_prompt], + outputs=output_image, + ) + gr.HTML( + """ +
+ * This demo's template was modified from OminiControl. +
+ """ + ) + return app if __name__ == "__main__": - demo.launch() + create_app().launch(debug=False, share=False, ssr_mode=False) diff --git a/app_old.py b/app_old.py new file mode 100644 index 0000000000000000000000000000000000000000..c6173ff8cd43a53e23fb23fcd73c2b856f20abcd --- /dev/null +++ b/app_old.py @@ -0,0 +1,374 @@ +import gradio as gr +import torch +import spaces +from PIL import Image, ImageDraw, ImageFont +# from src.condition import Condition +from diffusers.pipelines import FluxPipeline +import numpy as np +import requests +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file +import torch.multiprocessing as mp +### +import argparse +import logging +import math +import os +import re +import random +import shutil +from contextlib import nullcontext +from pathlib import Path +from PIL import Image +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor +from transformers.utils import ContextManagers +from omegaconf import OmegaConf +from copy import deepcopy +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr +from diffusers.utils import check_min_version, deprecate, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from src.flux.util import (configs, load_ae, load_clip, + load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint) +from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor, IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel +from src.flux.xflux_pipeline import XFluxSampler + +from image_datasets.dataset import loader, eval_image_pair_loader, image_resize + +from safetensors.torch import load_file +import json + + +# logger = get_logger(__name__, log_level="INFO") + + +def get_models(name: str, device, offload: bool, is_schnell: bool): + t5 = load_t5(device, max_length=256 if is_schnell else 512) + clip = load_clip(device) + clip.requires_grad_(False) + model = load_flow_model2(name, device="cpu") + vae = load_ae(name, device="cpu" if offload else device) + return model, vae, t5, clip + +args = OmegaConf.load("inference_configs/inference.yaml") #OmegaConf.load(parse_args()) +is_schnell = args.model_name == "flux-schnell" +set_seed(args.seed) +# logging_dir = os.path.join(args.output_dir, args.logging_dir) +device = "cuda" +dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell) + +# # load image encoder +# ip_image_encoder = 
CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to( +# # accelerator.device, dtype=torch.bfloat16 +# device, dtype=torch.bfloat16 +# ) +# ip_clip_image_processor = CLIPImageProcessor() + +if args.use_ip: + sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj) +elif args.use_spatial_condition: + sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding) +else: + sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None) + + +# @spaces.GPU +def generate(image, edit_prompt): + print("hello?????????!!!!!") + + # accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + # accelerator = Accelerator( + # gradient_accumulation_steps=1, + # mixed_precision=args.mixed_precision, + # log_with=args.report_to, + # project_config=accelerator_project_config, + # ) + + # Make one log on every process with the configuration for debugging. + # logging.basicConfig( + # format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + # datefmt="%m/%d/%Y %H:%M:%S", + # level=logging.INFO, + # ) + # logger.info(accelerator.state, main_process_only=False) + # if accelerator.is_local_main_process: + # datasets.utils.logging.set_verbosity_warning() + # transformers.utils.logging.set_verbosity_warning() + # diffusers.utils.logging.set_verbosity_info() + # else: + # datasets.utils.logging.set_verbosity_error() + # transformers.utils.logging.set_verbosity_error() + # diffusers.utils.logging.set_verbosity_error() + + + # if accelerator.is_main_process: + # if args.output_dir is not None: + # os.makedirs(args.output_dir, exist_ok=True) + # gpt_eval_path = os.path.join(args.output_dir,"Eval") + # os.makedirs(gpt_eval_path, exist_ok=True) + + # dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell) + # dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell) + + if args.use_lora: + lora_attn_procs = {} + if args.use_ip: + ip_attn_procs = {} + if args.double_blocks is None: + double_blocks_idx = list(range(19)) + else: + double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")] + + if args.single_blocks is None: + single_blocks_idx = list(range(38)) + elif args.single_blocks is not None: + single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")] + + if args.use_lora: + for name, attn_processor in dit.attn_processors.items(): + match = re.search(r'\.(\d+)\.', name) + if match: + layer_index = int(match.group(1)) + + if name.startswith("double_blocks") and layer_index in double_blocks_idx: + # if accelerator.is_main_process: + # print("setting LoRA Processor for", name) + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor( + dim=3072, rank=args.rank + ) + elif name.startswith("single_blocks") and layer_index in single_blocks_idx: + # if accelerator.is_main_process: + # print("setting LoRA Processor for", name) + lora_attn_procs[name] = SingleStreamBlockLoraProcessor( + dim=3072, rank=args.rank + ) + else: + lora_attn_procs[name] = attn_processor + + 
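# --- Illustrative sketch (an addition for clarity, not part of the original commit) ---
# The loop above swaps LoRA processors into the selected double/single blocks before
# dit.set_attn_processor(...) is called below. The real processor internals live in
# src/flux/modules/layers (DoubleStreamBlockLoraProcessor / SingleStreamBlockLoraProcessor);
# the snippet here only illustrates the generic low-rank update that "rank" controls,
# and the class and parameter names are hypothetical.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base projection plus a trainable rank-r residual branch."""
    def __init__(self, base: nn.Linear, rank: int = 16, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)          # pretrained weights stay frozen
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)           # start as an identity update
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.up(self.down(x))

# Only the two rank-r factors are trained, which is why rank: 16 in
# inference_configs/inference.yaml keeps the adapter tiny relative to dim=3072.
_demo = LoRALinear(nn.Linear(3072, 3072), rank=16)
assert _demo(torch.randn(1, 3072)).shape == (1, 3072)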
dit.set_attn_processor(lora_attn_procs) + + # if args.use_ip: + # # unpack checkpoint + # checkpoint = load_checkpoint(args.ip_local_path, args.ip_repo_id, args.ip_name) + # prefix = "double_blocks." + # # blocks = {} + # proj = {} + + # for key, value in checkpoint.items(): + # # if key.startswith(prefix): + # # blocks[key[len(prefix):].replace('.processor.', '.')] = value + # if key.startswith("ip_adapter_proj_model"): + # proj[key[len("ip_adapter_proj_model."):]] = value + + # # # load image encoder + # # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to( + # # # accelerator.device, dtype=torch.bfloat16 + # # device, dtype=torch.bfloat16 + # # ) + # # ip_clip_image_processor = CLIPImageProcessor() + + # # setup image embedding projection model + # ip_improj = ImageProjModel(4096, 768, 4) + # ip_improj.load_state_dict(proj) + # # ip_improj = ip_improj.to(accelerator.device, dtype=torch.bfloat16) + # ip_improj = ip_improj.to(device, dtype=torch.bfloat16) + + # ip_attn_procs = {} + + # for name, _ in dit.attn_processors.items(): + # ip_state_dict = {} + # for k in checkpoint.keys(): + # if name in k: + # ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] + # if ip_state_dict: + # ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) + # ip_attn_procs[name].load_state_dict(ip_state_dict) + # ip_attn_procs[name].to(accelerator.device, dtype=torch.bfloat16) + # else: + # ip_attn_procs[name] = dit.attn_processors[name] + # dit.set_attn_processor(ip_attn_procs) + + + vae.requires_grad_(False) + t5.requires_grad_(False) + clip.requires_grad_(False) + + + + # weight_dtype = torch.float32 + # if accelerator.mixed_precision == "fp16": + # weight_dtype = torch.float16 + # args.mixed_precision = accelerator.mixed_precision + # elif accelerator.mixed_precision == "bf16": + # weight_dtype = torch.bfloat16 + # args.mixed_precision = accelerator.mixed_precision + + + # print(f"Resuming from checkpoint {args.ckpt_dir}") + # dit_stat_dict = load_file(args.ckpt_dir) + # Get path from Hub + model_path = hf_hub_download( + repo_id="Boese0601/ByteMorpher", + filename="dit.safetensors" + ) + state_dict = load_file(model_path) + dit.load_state_dict(state_dict) + dit = dit.to(weight_dtype) + dit.eval() + + # test_dataloader = loader(**args.data_config) + test_dataloader = eval_image_pair_loader(**args.data_config) + + + + # from deepspeed import initialize + dit = accelerator.prepare(dit) + + # if accelerator.is_main_process: + # accelerator.init_trackers(args.tracker_project_name, {"test": None}) + + # logger.info("***** Running Evaluation *****") + # logger.info(f" Instantaneous batch size = {args.eval_batch_size}") + + + + # progress_bar = tqdm( + # range(0, len(test_dataloader)), + # initial=0, + # desc="Steps", + # disable=not accelerator.is_local_main_process, + # ) + + # for step, batch in enumerate(test_dataloader): + # with accelerator.accumulate(dit): + # img, tgt_image, prompt, edit_prompt, img_name, edit_name = batch + img = image_resize(image, 512) + w, h = img.size + new_w = (w // 32) * 32 + new_h = (h // 32) * 32 + img = img.resize((new_w, new_h)) + img = torch.from_numpy((np.array(img) / 127.5) - 1) + img = img.permute(2, 0, 1).unsqueeze(0) + + edit_prompt = edit_prompt + + # if args.use_ip: + # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj) + # elif 
args.use_spatial_condition: + # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding) + # else: + # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None) + with torch.no_grad(): + result = sampler(prompt=edit_prompt, + width=args.sample_width, + height=args.sample_height, + num_steps=args.sample_steps, + image_prompt=None, # ip_adapter + true_gs=args.cfg_scale, + seed=args.seed, + ip_scale=args.ip_scale if args.use_ip else 1.0, + source_image=img if args.use_spatial_condition else None, + ) + gen_img = result + + + + # progress_bar.update(1) + + # accelerator.wait_for_everyone() + # accelerator.end_training() + return gen_img + + +def get_samples(): + sample_list = [ + { + "image": "assets/0_camera_zoom/20486354.png", + "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.", + }, + ] + return [ + [ + Image.open(sample["image"]).resize((512, 512)), + sample["edit_prompt"], + ] + for sample in sample_list + ] + + +header = """ +# ByteMoprh + +
+arXiv +HuggingFace +GitHub +
+""" + + +def create_app(): + with gr.Blocks() as app: + gr.Markdown(header, elem_id="header") + with gr.Row(equal_height=False): + with gr.Column(variant="panel", elem_classes="inputPanel"): + original_image = gr.Image( + type="pil", label="Condition Image", width=300, elem_id="input" + ) + edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt") + submit_btn = gr.Button("Run", elem_id="submit_btn") + + with gr.Column(variant="panel", elem_classes="outputPanel"): + output_image = gr.Image(type="pil", elem_id="output") + + with gr.Row(): + examples = gr.Examples( + examples=get_samples(), + inputs=[original_image, edit_prompt], + label="Examples", + ) + + submit_btn.click( + fn=generate, + inputs=[original_image, edit_prompt], + outputs=output_image, + ) + gr.HTML( + """ +
+ * This demo's template was modified from OminiControl. +
+ """ + ) + return app + + +if __name__ == "__main__": + print("CUDA available:", torch.cuda.is_available()) + print("CUDA version:", torch.version.cuda) + print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None") + # mp.set_start_method("spawn", force=True) + create_app().launch(debug=False, share=True, ssr_mode=False) diff --git a/image_datasets/.DS_Store b/image_datasets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/image_datasets/.DS_Store differ diff --git a/image_datasets/dataset.py b/image_datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf2ee50521c94db52b04955c6e5215912737f07 --- /dev/null +++ b/image_datasets/dataset.py @@ -0,0 +1,231 @@ +import os +import pandas as pd +import numpy as np +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +import json +import random +import glob +import torch +import torchvision.transforms.functional as TF + + + +def image_resize(img, max_size=512): + w, h = img.size + if w >= h: + new_w = max_size + new_h = int((max_size / w) * h) + else: + new_h = max_size + new_w = int((max_size / h) * w) + return img.resize((new_w, new_h)) + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def crop_to_aspect_ratio(image, ratio="16:9"): + width, height = image.size + ratio_map = { + "16:9": (16, 9), + "4:3": (4, 3), + "1:1": (1, 1) + } + target_w, target_h = ratio_map[ratio] + target_ratio_value = target_w / target_h + + current_ratio = width / height + + if current_ratio > target_ratio_value: + new_width = int(height * target_ratio_value) + offset = (width - new_width) // 2 + crop_box = (offset, 0, offset + new_width, height) + else: + new_height = int(width / target_ratio_value) + offset = (height - new_height) // 2 + crop_box = (0, offset, width, offset + new_height) + + cropped_img = image.crop(crop_box) + return cropped_img + + +class CustomImageDataset(Dataset): + def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False): + self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + # self.images = glob.glob(img_dir +'**/*.jpg', recursive=True) + glob.glob(img_dir +'**/*.png', recursive=True) + glob.glob(img_dir +'**/*.jpeg', recursive=True) + self.images.sort() + self.img_size = img_size + self.caption_type = caption_type + self.random_ratio = random_ratio + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + try: + img = Image.open(self.images[idx]).convert('RGB') + + if self.random_ratio: + ratio = random.choice(["16:9", "default", "1:1", "4:3"]) + if ratio != "default": + img = crop_to_aspect_ratio(img, ratio) + img = image_resize(img, self.img_size) + w, h = img.size + new_w = (w // 32) * 32 + new_h = (h // 32) * 32 + img = img.resize((new_w, new_h)) + img = torch.from_numpy((np.array(img) / 127.5) - 1) + img = img.permute(2, 0, 1) + json_path = self.images[idx].split('.')[0] + '.' 
+ self.caption_type + if self.caption_type == "json": + prompt = json.load(open(json_path))['caption'] + else: + prompt = open(json_path).read() + return img, prompt + except Exception as e: + print(e) + return self.__getitem__(random.randint(0, len(self.images) - 1)) + + +def loader(train_batch_size, num_workers, **args): + dataset = CustomImageDataset(**args) + return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True) + + + +class ImageEditPairDataset(Dataset): + def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False, grayscale_editing=False, zoom_camera=False): + # self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] + self.images = glob.glob(img_dir +'**/*.jpg', recursive=True) + glob.glob(img_dir +'**/*.png', recursive=True) + glob.glob(img_dir +'**/*.jpeg', recursive=True) + self.images.sort() + self.img_size = img_size + self.caption_type = caption_type + self.random_ratio = random_ratio + self.grayscale_editing = grayscale_editing + self.zoom_camera = zoom_camera + if "ByteMorph-Bench" or "InstructMove" in img_dir: + self.eval = True + else: + self.eval = False + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + try: + img = Image.open(self.images[idx]).convert('RGB') + ori_width, ori_height = img.size + left_half = (0, 0, ori_width // 2, ori_height) + right_half = (ori_width // 2, 0, ori_width, ori_height) + src_image = img.crop(left_half) # Left half + tgt_image = img.crop(right_half) # Right half + # print("ori_width, ori_height: ",ori_width, ori_height) + if self.random_ratio: + ratio = random.choice(["16:9", "default", "1:1", "4:3"]) + if ratio != "default": + src_image = crop_to_aspect_ratio(src_image, ratio) + tgt_image = crop_to_aspect_ratio(tgt_image, ratio) + src_image = image_resize(src_image, self.img_size) + tgt_image = image_resize(tgt_image, self.img_size) + w, h = src_image.size + new_w = (w // 32) * 32 + new_h = (h // 32) * 32 + # print("new_w, new_h: ",new_w, new_h) + src_image = src_image.resize((new_w, new_h)) + src_image = torch.from_numpy((np.array(src_image) / 127.5) - 1) + src_image = src_image.permute(2, 0, 1) + tgt_image = tgt_image.resize((new_w, new_h)) + tgt_image = torch.from_numpy((np.array(tgt_image) / 127.5) - 1) + tgt_image = tgt_image.permute(2, 0, 1) + json_path = self.images[idx].split('.')[0] + '.' + self.caption_type + if self.eval: + image_name = self.images[idx].split('.')[0].split("/")[-1] + edit_type = self.images[idx].split('.')[0].split("/")[-2] + if self.caption_type == "json": + if not self.eval: + prompt = json.load(open(json_path))['caption'] + edit_prompt = json.load(open(json_path))['edit'] + else: + prompt = [] #json.load(open(json_path))['caption'] + edit_prompt = json.load(open(json_path))['edit'] + else: + raise NotImplementedError + # prompt = open(json_path).read() + if (not self.grayscale_editing) and (not self.zoom_camera): + if not self.eval: + return src_image, tgt_image, prompt, edit_prompt + else: + return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type + if self.grayscale_editing and (not self.zoom_camera): + # Grayscale = 0.2989 * R + 0.5870 * G + 0.1140 * B + grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :] + tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1) + edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details." 
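# --- Illustrative sketch (an addition for clarity, not part of the original commit) ---
# The branch above builds the grayscale editing target with the BT.601 luma weights
# (0.2989 R + 0.5870 G + 0.1140 B) and tiles the single channel back to three.
# A minimal standalone version, assuming a (3, H, W) tensor in [-1, 1] like src_image:
import torch

def to_grayscale_3ch(img_chw: torch.Tensor) -> torch.Tensor:
    r, g, b = img_chw[0], img_chw[1], img_chw[2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b   # luminance, shape (H, W)
    return gray.unsqueeze(0).repeat(3, 1, 1)      # replicate to 3 channels

_dummy = torch.rand(3, 512, 512) * 2 - 1          # same value range as src_image
assert to_grayscale_3ch(_dummy).shape == (3, 512, 512)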
+ if not self.eval: + return src_image, tgt_image, prompt, edit_prompt + else: + return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type + if (not self.grayscale_editing) and self.zoom_camera: + cropped = TF.center_crop(src_image, (256, 256)) + tgt_image = TF.resize(cropped, (512, 512)) + edit_prompt = "The central area of the input image is zoomed. The camera transitions from a wide shot to a closer position, narrowing its view." + if not self.eval: + return src_image, tgt_image, prompt, edit_prompt + else: + return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type + if self.grayscale_editing and self.zoom_camera: + grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :] + tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1) + tgt_image = TF.center_crop(tgt_image, (256, 256)) + tgt_image = TF.resize(tgt_image, (512, 512)) + edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details. And the central area of the input image is zoomed, the camera transitions from a wide shot to a closer position, narrowing its view." + if not self.eval: + return src_image, tgt_image, prompt, edit_prompt + else: + return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type + except Exception as e: + print(e) + return self.__getitem__(random.randint(0, len(self.images) - 1)) + + +def image_pair_loader(train_batch_size, num_workers, **args): + dataset = ImageEditPairDataset(**args) + return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True) + +def eval_image_pair_loader(eval_batch_size, num_workers, **args): + dataset = ImageEditPairDataset(**args) + return DataLoader(dataset, batch_size=eval_batch_size, num_workers=num_workers, shuffle=False) + + + +if __name__ == "__main__": + from src.flux.util import save_image + example_dataset = ImageEditPairDataset( + img_dir="", + img_size=512, + caption_type='json', + random_ratio=False, + grayscale_editing=False, + zoom_camera=False, + ) + + train_dataloader = DataLoader( + example_dataset, + batch_size=1, + num_workers=4, + shuffle=False, + ) + + for step, batch in enumerate(train_dataloader): + src_image, tgt_image, prompt, edit_prompt = batch + os.makedirs("./debug", exist_ok=True) + save_image(src_image, f"./debug/{step}-src_img.jpg") + save_image(tgt_image, f"./debug/{step}-tgt_img.jpg") + if step == 3: + breakpoint() diff --git a/inference_configs/inference.yaml b/inference_configs/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..189e02a0ef559fc65a2ce2f995cd9b008fd7fad8 --- /dev/null +++ b/inference_configs/inference.yaml @@ -0,0 +1,33 @@ +model_name: "flux-dev" +use_spatial_condition: true +share_position_embedding: true +use_share_weight_referencenet: false +use_ip: false +ip_local_path: null +ip_repo_id: null +ip_name: null +ip_scale: 1.0 +use_lora: false +data_config: + eval_batch_size: 1 + num_workers: 0 + img_size: 512 + img_dir: output_bench/ #./ByteMorph-Bench/ + grayscale_editing: false + zoom_camera: false + random_ratio: false +report_to: wandb +eval_batch_size: 1 +ckpt_dir: ./pretrained_weights/ByteMorpher/dit.safetensors +output_dir: ./test_log/seedmorpher/ +logging_dir: logs +mixed_precision: "bf16" +rank: 16 +single_blocks: null +double_blocks: null +disable_sampling: false +sample_width: 512 +sample_height: 512 +sample_steps: 25 +seed: 42 +cfg_scale: 3.5 \ No newline at end of file diff --git 
a/requirements.txt b/requirements.txt index 73d01db64bd054c7d21fd0bd79b3af087c468809..233e8b8c2734a5b326ef3211b083ef423b7777af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,26 @@ -accelerate +# --extra-index-url https://download.pytorch.org/whl/cu124 +# torch==2.6.0 +# torchvision==0.21.0 +# torchaudio==2.6.0 + +gradio>=4.0 +accelerate==0.30.1 +deepspeed==0.14.4 +einops==0.8.0 +transformers==4.43.3 +huggingface-hub==0.24.5 +optimum-quanto +datasets +omegaconf diffusers -invisible_watermark -torch -transformers -xformers \ No newline at end of file +sentencepiece +opencv-python +matplotlib +onnxruntime +timm +wandb + +setuptools +wheel + + diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..36ef5befcbe5f66fbe8c9f332bcae89cde6265d6 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/flux/.DS_Store b/src/flux/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/src/flux/.DS_Store differ diff --git a/src/flux/__init__.py b/src/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43c365a49d6980e88acba10ef3069f110a59644a --- /dev/null +++ b/src/flux/__init__.py @@ -0,0 +1,11 @@ +try: + from ._version import version as __version__ # type: ignore + from ._version import version_tuple +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/src/flux/__main__.py b/src/flux/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cf0fd2444d4cda4053fa74dad3371556b886e5 --- /dev/null +++ b/src/flux/__main__.py @@ -0,0 +1,4 @@ +from .cli import app + +if __name__ == "__main__": + app() diff --git a/src/flux/annotator/canny/__init__.py b/src/flux/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/src/flux/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/src/flux/annotator/ckpts/ckpts.txt b/src/flux/annotator/ckpts/ckpts.txt new file mode 100644 index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b --- /dev/null +++ b/src/flux/annotator/ckpts/ckpts.txt @@ -0,0 +1 @@ +Weights here. \ No newline at end of file diff --git a/src/flux/annotator/dwpose/__init__.py b/src/flux/annotator/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6e172d05c9de3f1cdd61e330ad8d6dde46dfdd --- /dev/null +++ b/src/flux/annotator/dwpose/__init__.py @@ -0,0 +1,68 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . 
import util +from .wholebody import Wholebody + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + canvas = util.draw_bodypose(canvas, candidate, subset) + + canvas = util.draw_handpose(canvas, hands) + + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class DWposeDetector: + def __init__(self, device): + + self.pose_estimation = Wholebody(device) + + def __call__(self, oriImg): + oriImg = oriImg.copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + return draw_pose(pose, H, W) diff --git a/src/flux/annotator/dwpose/onnxdet.py b/src/flux/annotator/dwpose/onnxdet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0411c96a5eef41e981bde5481ef7786b242f1fa --- /dev/null +++ b/src/flux/annotator/dwpose/onnxdet.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. 
Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + output = session.run(None, ort_inputs) + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 
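# Note: YOLOX predictions arrive as (cx, cy, w, h) in the 640x640 letterboxed frame;
# the four lines above convert them to corner format (x1, y1, x2, y2), and dividing by
# `ratio` below maps the boxes back to original-image coordinates before class-aware NMS.
# The surviving detections are then filtered to class index 0 (person) with score > 0.3.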
+ boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/src/flux/annotator/dwpose/onnxpose.py b/src/flux/annotator/dwpose/onnxpose.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f --- /dev/null +++ b/src/flux/annotator/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. 
+ """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. 
+ (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/src/flux/annotator/dwpose/util.py b/src/flux/annotator/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/src/flux/annotator/dwpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 
170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - 
posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/src/flux/annotator/dwpose/wholebody.py b/src/flux/annotator/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..d73f19d61c238c47cf7de98d01385b2150a5361f --- /dev/null +++ b/src/flux/annotator/dwpose/wholebody.py @@ -0,0 +1,48 @@ +import cv2 +import numpy as np + +import onnxruntime as ort +from huggingface_hub import hf_hub_download +from .onnxdet import inference_detector +from .onnxpose import 
inference_pose + + +class Wholebody: + def __init__(self, device="cuda:0"): + providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] + onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx") + onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx") + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/src/flux/annotator/hed/__init__.py b/src/flux/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70d11c9e62133149d38091a597a1b6691ff8f1b6 --- /dev/null +++ b/src/flux/annotator/hed/__init__.py @@ -0,0 +1,95 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 
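+# A rough usage sketch (illustrative only; the import path is inferred from this file's
+# location in the repository and a CUDA device is assumed). The detector defined below
+# takes an H x W x 3 uint8 RGB array and returns a single-channel uint8 soft edge map of
+# the same spatial size:
+#
+#     from src.flux.annotator.hed import HEDdetector, nms
+#     hed = HEDdetector()             # fetches ControlNetHED.pth on first use
+#     edge = hed(rgb_image)           # (H, W) uint8 soft edges
+#     scribble = nms(edge, 127, 3.0)  # optional: thin and binarize to 0/255
+#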
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np + +from huggingface_hub import hf_hub_download +from einops import rearrange +from ...annotator.util import annotator_ckpts_path + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth") + self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().cuda() + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + 
np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/src/flux/annotator/midas/LICENSE b/src/flux/annotator/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/src/flux/annotator/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/flux/annotator/midas/__init__.py b/src/flux/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d --- /dev/null +++ b/src/flux/annotator/midas/__init__.py @@ -0,0 +1,42 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/src/flux/annotator/midas/api.py b/src/flux/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..6226a39d80de978162a7238cec1c4d4a64bacbe9 --- /dev/null +++ b/src/flux/annotator/midas/api.py @@ -0,0 +1,168 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from huggingface_hub import hf_hub_download + +from .midas.dpt_depth import DPTDepthModel 
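+# Note on the imports around this point: in load_model() below, "dpt_large" and
+# "dpt_hybrid" are served by DPTDepthModel, "midas_v21" by MidasNet, and
+# "midas_v21_small" by MidasNet_small; only the dpt_hybrid checkpoint is fetched
+# automatically via hf_hub_download when it is missing from annotator_ckpts_path.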
+from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from ...annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt") + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + 
) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/src/flux/annotator/midas/midas/__init__.py b/src/flux/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/midas/midas/base_model.py b/src/flux/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/src/flux/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/src/flux/annotator/midas/midas/blocks.py b/src/flux/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/src/flux/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = 
out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/src/flux/annotator/midas/midas/dpt_depth.py b/src/flux/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/src/flux/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/src/flux/annotator/midas/midas/midas_net.py b/src/flux/annotator/midas/midas/midas_net.py new file mode 100644 index 
0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/src/flux/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/src/flux/annotator/midas/midas/midas_net_custom.py b/src/flux/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/src/flux/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. 
+ """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/src/flux/annotator/midas/midas/transforms.py b/src/flux/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/src/flux/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if 
self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/midas/midas/vit.py b/src/flux/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/src/flux/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h 
// pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + 
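+    # The four act_postprocess heads below are the DPT-style "reassemble" stages: each
+    # strips (or folds in) the readout token, transposes the token sequence back onto a
+    # 2D grid, projects to the target channel count, and rescales so the hooked blocks
+    # yield features at roughly 1/4, 1/8, 1/16 and 1/32 of the input resolution
+    # (transposed-conv x4, transposed-conv x2, identity, strided-conv /2 respectively).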
pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + 
use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
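+    # forward_flex (bound above) calls _resize_pos_embed (bound below) to bilinearly
+    # interpolate the learned position embedding onto the actual token grid, so this
+    # backbone can take inputs whose size differs from the 384x384 pretraining resolution.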
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/src/flux/annotator/midas/utils.py b/src/flux/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/src/flux/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. 
+ + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/src/flux/annotator/mlsd/LICENSE b/src/flux/annotator/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/src/flux/annotator/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/src/flux/annotator/mlsd/__init__.py b/src/flux/annotator/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5028aef051a1e67caae1f0d23a0b0dbca883a7f8 --- /dev/null +++ b/src/flux/annotator/mlsd/__init__.py @@ -0,0 +1,40 @@ +# MLSD Line Detection +# From https://github.com/navervision/mlsd +# Apache-2.0 license + +import cv2 +import numpy as np +import torch +import os + +from einops import rearrange +from huggingface_hub import hf_hub_download +from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + +from ...annotator.util import annotator_ckpts_path + + +class MLSDdetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth") + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + self.model = model.cuda().eval() + + def __call__(self, input_image, thr_v, thr_d): + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + return img_output[:, :, 0] diff --git a/src/flux/annotator/mlsd/models/mbv2_mlsd_large.py b/src/flux/annotator/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/src/flux/annotator/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = 
self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = 
[ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py b/src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + 
nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult 
(float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', 
align_corners=True) + + return x \ No newline at end of file diff --git a/src/flux/annotator/mlsd/utils.py b/src/flux/annotator/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..848a9fd7f9ff9d909f18c5d3ff55786c5a4b547a --- /dev/null +++ b/src/flux/annotator/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to("cuda:4") + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image 
= torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt 
> 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... 
+ ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! 
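+                # An intersection is kept only if, for BOTH of its two segments, it either
+                # (a) lies beyond the segment (farther-endpoint distance >= segment length)
+                #     but no more than outside_ratio * length from the nearer endpoint, or
+                # (b) lies within the segment's span (farther-endpoint distance <= length)
+                #     and no more than inside_ratio * length from the nearer endpoint.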
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... 
+ ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + 
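+            # The per-square terms (overlap, degree and length from this loop; area and
+            # center computed below) are later combined with the params['w_*'] weights
+            # into score_array to rank the candidate squares.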
###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/src/flux/annotator/tile/__init__.py b/src/flux/annotator/tile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c96899289d6e04796140cd7eff9c08e5f693af02 --- /dev/null +++ b/src/flux/annotator/tile/__init__.py @@ -0,0 +1,26 @@ +import random +import cv2 +from .guided_filter import FastGuidedFilter + + +class TileDetector: + # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0 + def __init__(self): + pass + + def __call__(self, image): + blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0] + radius = random.sample([i for i in range(1, 40, 2)], k=1)[0] + eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0] + scale_factor = random.sample([i / 10. 
for i in range(10, 181, 5)], k=1)[0] + + ksize = int(blur_strength) + if ksize % 2 == 0: + ksize += 1 + + if random.random() > 0.5: + image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2) + if random.random() > 0.5: + filter = FastGuidedFilter(image, radius, eps, scale_factor) + image = filter.filter(image) + return image diff --git a/src/flux/annotator/tile/guided_filter.py b/src/flux/annotator/tile/guided_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7172a5e144672eea26551ef75f70b90a2f96d6 --- /dev/null +++ b/src/flux/annotator/tile/guided_filter.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +## @package guided_filter.core.filters +# +# Implementation of guided filter. +# * GuidedFilter: Original guided filter. +# * FastGuidedFilter: Fast version of the guided filter. +# @author tody +# @date 2015/08/26 + +import numpy as np +import cv2 + +## Convert image into float32 type. +def to32F(img): + if img.dtype == np.float32: + return img + return (1.0 / 255.0) * np.float32(img) + +## Convert image into uint8 type. +def to8U(img): + if img.dtype == np.uint8: + return img + return np.clip(np.uint8(255.0 * img), 0, 255) + +## Return if the input image is gray or not. +def _isGray(I): + return len(I.shape) == 2 + + +## Return down sampled image. +# @param scale (w/s, h/s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _downSample(I, scale=4, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST) + + +## Return up sampled image. +# @param scale (w*s, h*s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _upSample(I, scale=2, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + +## Fast guide filter. +class FastGuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + # @param scale Down sampled scale. + def __init__(self, I, radius=5, epsilon=0.4, scale=4): + I_32F = to32F(I) + self._I = I_32F + h, w = I.shape[:2] + + I_sub = _downSample(I_32F, scale) + + self._I_sub = I_sub + radius = int(radius / scale) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + shape_original = p.shape[:2] + + p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2]) + + if _isGray(p_sub): + return self._filterGray(p_sub, shape_original) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original) + return to8U(q) + + def _filterGray(self, p_sub, shape_original): + ab_sub = self._guided_filter._computeCoefficients(p_sub) + ab = [_upSample(abi, shape=shape_original) for abi in ab_sub] + return self._guided_filter._computeOutput(ab, self._I) + + +## Guide filter. +class GuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. 
+ # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + I_32F = to32F(I) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return to8U(self._guided_filter.filter(p)) + + +## Common parts of guided filter. +# +# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor. +# Based on guided_filter._computeCoefficients, guided_filter._computeOutput, +# GuidedFilterCommon.filter computes filtered image for color and gray. +class GuidedFilterCommon: + def __init__(self, guided_filter): + self._guided_filter = guided_filter + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + if _isGray(p_32F): + return self._filterGray(p_32F) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_32F[:, :, ci]) + return q + + def _filterGray(self, p): + ab = self._guided_filter._computeCoefficients(p) + return self._guided_filter._computeOutput(ab, self._guided_filter._I) + + +## Guided filter for gray guidance image. +class GuidedFilterGray: + # @param I Input gray guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + self._I_mean = cv2.blur(I, (r, r)) + I_mean_sq = cv2.blur(I ** 2, (r, r)) + self._I_var = I_mean_sq - self._I_mean ** 2 + + def _computeCoefficients(self, p): + r = self._radius + p_mean = cv2.blur(p, (r, r)) + p_cov = p_mean - self._I_mean * p_mean + a = p_cov / (self._I_var + self._epsilon) + b = p_mean - a * self._I_mean + a_mean = cv2.blur(a, (r, r)) + b_mean = cv2.blur(b, (r, r)) + return a_mean, b_mean + + def _computeOutput(self, ab, I): + a_mean, b_mean = ab + return a_mean * I + b_mean + + +## Guided filter for color guidance image. +class GuidedFilterColor: + # @param I Input color guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.2): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. 
+ def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + eps = self._epsilon + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + self._Ir_mean = cv2.blur(Ir, (r, r)) + self._Ig_mean = cv2.blur(Ig, (r, r)) + self._Ib_mean = cv2.blur(Ib, (r, r)) + + Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps + Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean + Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean + Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps + Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean + Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps + + Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var + Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var + Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var + Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var + Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var + Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var + + I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var + Irr_inv /= I_cov + Irg_inv /= I_cov + Irb_inv /= I_cov + Igg_inv /= I_cov + Igb_inv /= I_cov + Ibb_inv /= I_cov + + self._Irr_inv = Irr_inv + self._Irg_inv = Irg_inv + self._Irb_inv = Irb_inv + self._Igg_inv = Igg_inv + self._Igb_inv = Igb_inv + self._Ibb_inv = Ibb_inv + + def _computeCoefficients(self, p): + r = self._radius + I = self._I + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + p_mean = cv2.blur(p, (r, r)) + + Ipr_mean = cv2.blur(Ir * p, (r, r)) + Ipg_mean = cv2.blur(Ig * p, (r, r)) + Ipb_mean = cv2.blur(Ib * p, (r, r)) + + Ipr_cov = Ipr_mean - self._Ir_mean * p_mean + Ipg_cov = Ipg_mean - self._Ig_mean * p_mean + Ipb_cov = Ipb_mean - self._Ib_mean * p_mean + + ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov + ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov + ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov + b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean + + ar_mean = cv2.blur(ar, (r, r)) + ag_mean = cv2.blur(ag, (r, r)) + ab_mean = cv2.blur(ab, (r, r)) + b_mean = cv2.blur(b, (r, r)) + + return ar_mean, ag_mean, ab_mean, b_mean + + def _computeOutput(self, ab, I): + ar_mean, ag_mean, ab_mean, b_mean = ab + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + q = (ar_mean * Ir + + ag_mean * Ig + + ab_mean * Ib + + b_mean) + + return q diff --git a/src/flux/annotator/util.py b/src/flux/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/src/flux/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img 
= cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/src/flux/annotator/zoe/LICENSE b/src/flux/annotator/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/src/flux/annotator/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/src/flux/annotator/zoe/__init__.py b/src/flux/annotator/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7628090932e35bdd71d041069ae62f6a731f60d4 --- /dev/null +++ b/src/flux/annotator/zoe/__init__.py @@ -0,0 +1,48 @@ +# ZoeDepth +# https://github.com/isl-org/ZoeDepth + +import os +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.utils.config import get_config +from ...annotator.util import annotator_ckpts_path +from huggingface_hub import hf_hub_download + + +class ZoeDetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt") + conf = get_config("zoedepth", "infer") + model = ZoeDepth.build_from_config(conf) + model.load_state_dict(torch.load(model_path)['model'], strict=False) + model = model.cuda() + model.device = 'cuda' + model.eval() + self.model = model + + def __call__(self, input_image): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + return depth_image diff --git a/src/flux/annotator/zoe/zoedepth/data/__init__.py b/src/flux/annotator/zoe/zoedepth/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, 
free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/data/data_mono.py b/src/flux/annotator/zoe/zoedepth/data/data_mono.py new file mode 100644 index 0000000000000000000000000000000000000000..80a8486f239a35331df553f490e213f9bf71e735 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/data_mono.py @@ -0,0 +1,573 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
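+
+# DepthDataLoader builds the torch DataLoader for a given config/mode: it dispatches to the
+# dataset-specific evaluation loaders (ibims, sunrgbd, diml, diode, hypersim, vkitti/vkitti2,
+# ddad) and otherwise wraps DataLoadPreprocess, which reads NYU/KITTI file lists and applies
+# KB-cropping and train-time augmentation. MixedNYUKITTI interleaves the NYU and KITTI
+# training loaders sample by sample via repetitive_roundrobin.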
+ +# File author: Shariq Farooq Bhat + +# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee + +import itertools +import os +import random + +import numpy as np +import cv2 +import torch +import torch.nn as nn +import torch.utils.data.distributed +from zoedepth.utils.easydict import EasyDict as edict +from PIL import Image, ImageOps +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + +from zoedepth.utils.config import change_dataset + +from .ddad import get_ddad_loader +from .diml_indoor_test import get_diml_indoor_loader +from .diml_outdoor_test import get_diml_outdoor_loader +from .diode import get_diode_loader +from .hypersim import get_hypersim_loader +from .ibims import get_ibims_loader +from .sun_rgbd_loader import get_sunrgbd_loader +from .vkitti import get_vkitti_loader +from .vkitti2 import get_vkitti2_loader + +from .preprocess import CropParams, get_white_border, get_black_border + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def preprocessing_transforms(mode, **kwargs): + return transforms.Compose([ + ToTensor(mode=mode, **kwargs) + ]) + + +class DepthDataLoader(object): + def __init__(self, config, mode, device='cpu', transform=None, **kwargs): + """ + Data loader for depth datasets + + Args: + config (dict): Config dictionary. Refer to utils/config.py + mode (str): "train" or "online_eval" + device (str, optional): Device to load the data on. Defaults to 'cpu'. + transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None. + """ + + self.config = config + + if config.dataset == 'ibims': + self.data = get_ibims_loader(config, batch_size=1, num_workers=1) + return + + if config.dataset == 'sunrgbd': + self.data = get_sunrgbd_loader( + data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_indoor': + self.data = get_diml_indoor_loader( + data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_outdoor': + self.data = get_diml_outdoor_loader( + data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1) + return + + if "diode" in config.dataset: + self.data = get_diode_loader( + config[config.dataset+"_root"], batch_size=1, num_workers=1) + return + + if config.dataset == 'hypersim_test': + self.data = get_hypersim_loader( + config.hypersim_test_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti': + self.data = get_vkitti_loader( + config.vkitti_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti2': + self.data = get_vkitti2_loader( + config.vkitti2_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'ddad': + self.data = get_ddad_loader(config.ddad_root, resize_shape=( + 352, 1216), batch_size=1, num_workers=1) + return + + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + + if transform is None: + transform = preprocessing_transforms(mode, size=img_size) + + if mode == 'train': + + Dataset = DataLoadPreprocess + self.training_samples = Dataset( + config, mode, transform=transform, device=device) + + if config.distributed: + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_samples) + else: + self.train_sampler = None + + self.data 
= DataLoader(self.training_samples, + batch_size=config.batch_size, + shuffle=(self.train_sampler is None), + num_workers=config.workers, + pin_memory=True, + persistent_workers=True, + # prefetch_factor=2, + sampler=self.train_sampler) + + elif mode == 'online_eval': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + if config.distributed: # redundant. here only for readability and to be more explicit + # Give whole test set to all processes (and report evaluation only on one) regardless + self.eval_sampler = None + else: + self.eval_sampler = None + self.data = DataLoader(self.testing_samples, 1, + shuffle=kwargs.get("shuffle_test", False), + num_workers=1, + pin_memory=False, + sampler=self.eval_sampler) + + elif mode == 'test': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + self.data = DataLoader(self.testing_samples, + 1, shuffle=False, num_workers=1) + + else: + print( + 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode)) + + +def repetitive_roundrobin(*iterables): + """ + cycles through iterables but sample wise + first yield first sample from first iterable then first sample from second iterable and so on + then second sample from first iterable then second sample from second iterable and so on + + If one iterable is shorter than the others, it is repeated until all iterables are exhausted + repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E + """ + # Repetitive roundrobin + iterables_ = [iter(it) for it in iterables] + exhausted = [False] * len(iterables) + while not all(exhausted): + for i, it in enumerate(iterables_): + try: + yield next(it) + except StopIteration: + exhausted[i] = True + iterables_[i] = itertools.cycle(iterables[i]) + # First elements may get repeated if one iterable is shorter than the others + yield next(iterables_[i]) + + +class RepetitiveRoundRobinDataLoader(object): + def __init__(self, *dataloaders): + self.dataloaders = dataloaders + + def __iter__(self): + return repetitive_roundrobin(*self.dataloaders) + + def __len__(self): + # First samples get repeated, thats why the plus one + return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1) + + +class MixedNYUKITTI(object): + def __init__(self, config, mode, device='cpu', **kwargs): + config = edict(config) + config.workers = config.workers // 2 + self.config = config + nyu_conf = change_dataset(edict(config), 'nyu') + kitti_conf = change_dataset(edict(config), 'kitti') + + # make nyu default for testing + self.config = config = nyu_conf + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + if mode == 'train': + nyu_loader = DepthDataLoader( + nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + kitti_loader = DepthDataLoader( + kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + # It has been changed to repetitive roundrobin + self.data = RepetitiveRoundRobinDataLoader( + nyu_loader, kitti_loader) + else: + self.data = DepthDataLoader(nyu_conf, mode, device=device).data + + +def remove_leading_slash(s): + if s[0] == '/' or s[0] == '\\': + return s[1:] + return s + + +class CachedReader: + def __init__(self, shared_dict=None): + if shared_dict: + self._cache = shared_dict + else: + self._cache = {} + + def open(self, fpath): + im = self._cache.get(fpath, None) + if im is None: + im = self._cache[fpath] = Image.open(fpath) + 
return im + + +class ImReader: + def __init__(self): + pass + + # @cache + def open(self, fpath): + return Image.open(fpath) + + +class DataLoadPreprocess(Dataset): + def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs): + self.config = config + if mode == 'online_eval': + with open(config.filenames_file_eval, 'r') as f: + self.filenames = f.readlines() + else: + with open(config.filenames_file, 'r') as f: + self.filenames = f.readlines() + + self.mode = mode + self.transform = transform + self.to_tensor = ToTensor(mode) + self.is_for_online_eval = is_for_online_eval + if config.use_shared_dict: + self.reader = CachedReader(config.shared_dict) + else: + self.reader = ImReader() + + def postprocess(self, sample): + return sample + + def __getitem__(self, idx): + sample_path = self.filenames[idx] + focal = float(sample_path.split()[2]) + sample = {} + + if self.mode == 'train': + if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[3])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[4])) + else: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[0])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[1])) + + image = self.reader.open(image_path) + depth_gt = self.reader.open(depth_path) + w, h = image.size + + if self.config.do_kb_crop: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth_gt = depth_gt.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + + # Avoid blank boundaries due to pixel registration? + # Train images have white border. Test images have black border. 
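+            # get_white_border() locates the white registration border of NYU training frames;
+            # after cropping it away, the RGB is reflect-padded back to the original size while
+            # the depth map is zero-padded, so the padded ring stays outside the min/max-depth
+            # validity mask computed below (assuming min_depth > 0).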
+ if self.config.dataset == 'nyu' and self.config.avoid_boundary: + # print("Avoiding Blank Boundaries!") + # We just crop and pad again with reflect padding to original size + # original_size = image.size + crop_params = get_white_border(np.array(image, dtype=np.uint8)) + image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + + # Use reflect padding to fill the blank + image = np.array(image) + image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect') + image = Image.fromarray(image) + + depth_gt = np.array(depth_gt) + depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0) + depth_gt = Image.fromarray(depth_gt) + + + if self.config.do_random_rotate and (self.config.aug): + random_angle = (random.random() - 0.5) * 2 * self.config.degree + image = self.rotate_image(image, random_angle) + depth_gt = self.rotate_image( + depth_gt, random_angle, flag=Image.NEAREST) + + image = np.asarray(image, dtype=np.float32) / 255.0 + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + if self.config.aug and (self.config.random_crop): + image, depth_gt = self.random_crop( + image, depth_gt, self.config.input_height, self.config.input_width) + + if self.config.aug and self.config.random_translate: + # print("Random Translation!") + image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation) + + image, depth_gt = self.train_preprocess(image, depth_gt) + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample = {'image': image, 'depth': depth_gt, 'focal': focal, + 'mask': mask, **sample} + + else: + if self.mode == 'online_eval': + data_path = self.config.data_path_eval + else: + data_path = self.config.data_path + + image_path = os.path.join( + data_path, remove_leading_slash(sample_path.split()[0])) + image = np.asarray(self.reader.open(image_path), + dtype=np.float32) / 255.0 + + if self.mode == 'online_eval': + gt_path = self.config.gt_path_eval + depth_path = os.path.join( + gt_path, remove_leading_slash(sample_path.split()[1])) + has_valid_depth = False + try: + depth_gt = self.reader.open(depth_path) + has_valid_depth = True + except IOError: + depth_gt = False + # print('Missing gt for {}'.format(image_path)) + + if has_valid_depth: + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + mask = np.logical_and( + depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...] 
+ else: + mask = False + + if self.config.do_kb_crop: + height = image.shape[0] + width = image.shape[1] + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + image = image[top_margin:top_margin + 352, + left_margin:left_margin + 1216, :] + if self.mode == 'online_eval' and has_valid_depth: + depth_gt = depth_gt[top_margin:top_margin + + 352, left_margin:left_margin + 1216, :] + + if self.mode == 'online_eval': + sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1], + 'mask': mask} + else: + sample = {'image': image, 'focal': focal} + + if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']): + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample['mask'] = mask + + if self.transform: + sample = self.transform(sample) + + sample = self.postprocess(sample) + sample['dataset'] = self.config.dataset + sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]} + + return sample + + def rotate_image(self, image, angle, flag=Image.BILINEAR): + result = image.rotate(angle, resample=flag) + return result + + def random_crop(self, img, depth, height, width): + assert img.shape[0] >= height + assert img.shape[1] >= width + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + x = random.randint(0, img.shape[1] - width) + y = random.randint(0, img.shape[0] - height) + img = img[y:y + height, x:x + width, :] + depth = depth[y:y + height, x:x + width, :] + + return img, depth + + def random_translate(self, img, depth, max_t=20): + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + p = self.config.translate_prob + do_translate = random.random() + if do_translate > p: + return img, depth + x = random.randint(-max_t, max_t) + y = random.randint(-max_t, max_t) + M = np.float32([[1, 0, x], [0, 1, y]]) + # print(img.shape, depth.shape) + img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0])) + depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0])) + depth = depth.squeeze()[..., None] # add channel dim back. 
Affine warp removes it + # print("after", img.shape, depth.shape) + return img, depth + + def train_preprocess(self, image, depth_gt): + if self.config.aug: + # Random flipping + do_flip = random.random() + if do_flip > 0.5: + image = (image[:, ::-1, :]).copy() + depth_gt = (depth_gt[:, ::-1, :]).copy() + + # Random gamma, brightness, color augmentation + do_augment = random.random() + if do_augment > 0.5: + image = self.augment_image(image) + + return image, depth_gt + + def augment_image(self, image): + # gamma augmentation + gamma = random.uniform(0.9, 1.1) + image_aug = image ** gamma + + # brightness augmentation + if self.config.dataset == 'nyu': + brightness = random.uniform(0.75, 1.25) + else: + brightness = random.uniform(0.9, 1.1) + image_aug = image_aug * brightness + + # color augmentation + colors = np.random.uniform(0.9, 1.1, size=3) + white = np.ones((image.shape[0], image.shape[1])) + color_image = np.stack([white * colors[i] for i in range(3)], axis=2) + image_aug *= color_image + image_aug = np.clip(image_aug, 0, 1) + + return image_aug + + def __len__(self): + return len(self.filenames) + + +class ToTensor(object): + def __init__(self, mode, do_normalize=False, size=None): + self.mode = mode + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity() + self.size = size + if size is not None: + self.resize = transforms.Resize(size=size) + else: + self.resize = nn.Identity() + + def __call__(self, sample): + image, focal = sample['image'], sample['focal'] + image = self.to_tensor(image) + image = self.normalize(image) + image = self.resize(image) + + if self.mode == 'test': + return {'image': image, 'focal': focal} + + depth = sample['depth'] + if self.mode == 'train': + depth = self.to_tensor(depth) + return {**sample, 'image': image, 'depth': depth, 'focal': focal} + else: + has_valid_depth = sample['has_valid_depth'] + image = self.resize(image) + return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample['image_path'], 'depth_path': sample['depth_path']} + + def to_tensor(self, pic): + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError( + 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img diff --git a/src/flux/annotator/zoe/zoedepth/data/ddad.py b/src/flux/annotator/zoe/zoedepth/data/ddad.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd0492bdec767685d3a21992b4a26e62d002d97 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/ddad.py @@ -0,0 +1,117 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
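+
+# Evaluation loader for the DDAD benchmark: RGB frames (*.png, expected to end in "_rgb.png")
+# are paired with "*_depth.npy" depth maps in metres, and images are resized to the requested
+# resize_shape before batching.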
+ +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self, resize_shape): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(resize_shape) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "ddad"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DDAD(Dataset): + def __init__(self, data_dir_root, resize_shape): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join(data_dir_root, '*.png')) + self.depth_files = [r.replace("_rgb.png", "_depth.npy") + for r in self.image_files] + self.transform = ToTensor(resize_shape) + + def __getitem__(self, idx): + + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs): + dataset = DDAD(data_dir_root, resize_shape) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py b/src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f720ad9aefaee78ef4ec363dfef0f82ace850a6d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diml_indoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Indoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "LR", '*', 'color', '*.png')) + self.depth_files = [r.replace("color", "depth_filled").replace( + "_c.png", "_depth_filled.png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIML_Indoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR") +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR") diff --git a/src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py b/src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8670b48f5febafb819dac22848ad79ccb5dd5ae4 --- /dev/null +++ 
b/src/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py @@ -0,0 +1,114 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Outdoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "*", 'outleft', '*.png')) + self.depth_files = [r.replace("outleft", "depthmap") + for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth, dataset="diml_outdoor") + + # return sample + return self.transform(sample) + + def __len__(self): + return len(self.image_files) + + +def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = 
DIML_Outdoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR") +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR") diff --git a/src/flux/annotator/zoe/zoedepth/data/diode.py b/src/flux/annotator/zoe/zoedepth/data/diode.py new file mode 100644 index 0000000000000000000000000000000000000000..1510c87116b8f70ce2e1428873a8e4da042bee23 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/diode.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(480) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diode"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIODE(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /scene_#/scan_#/*.png + self.image_files = glob.glob( + os.path.join(data_dir_root, '*', '*', '*.png')) + self.depth_files = [r.replace(".png", "_depth.npy") + for r in self.image_files] + self.depth_mask_files = [ + r.replace(".png", "_depth_mask.npy") for r in self.image_files] + self.transform = 
ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + depth_mask_path = self.depth_mask_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # in meters + valid = np.load(depth_mask_path) # binary + + # depth[depth > 8] = -1 + # depth = depth[..., None] + + sample = dict(image=image, depth=depth, valid=valid) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diode_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIODE(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diode_loader(data_dir_root="datasets/diode/val/outdoor") diff --git a/src/flux/annotator/zoe/zoedepth/data/hypersim.py b/src/flux/annotator/zoe/zoedepth/data/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..4334198971830200f72ea2910d03f4c7d6a43334 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/hypersim.py @@ -0,0 +1,138 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
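+
+# Evaluation loader for Hypersim: tonemapped previews (*.tonemap.jpg) are paired with
+# *.depth_meters.hdf5 files from the matching _geometry_hdf5 folder. The stored values are
+# Euclidean ray distances, so hypersim_distance_to_depth() converts them to planar (z-axis)
+# depth using the fixed 1024x768 image plane and focal length 886.81.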
+ +# File author: Shariq Farooq Bhat + +import glob +import os + +import h5py +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +def hypersim_distance_to_depth(npyDistance): + intWidth, intHeight, fltFocal = 1024, 768, 886.81 + + npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape( + 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None] + npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5, + intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None] + npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32) + npyImageplane = np.concatenate( + [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2) + + npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal + return npyDepth + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "hypersim"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class HyperSim(Dataset): + def __init__(self, data_dir_root): + # image paths are of the form //images/scene_cam_#_final_preview/*.tonemap.jpg + # depth paths are of the form //images/scene_cam_#_final_preview/*.depth_meters.hdf5 + self.image_files = glob.glob(os.path.join( + data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg')) + self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace( + ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + + # depth from hdf5 + depth_fd = h5py.File(depth_path, "r") + # in meters (Euclidean distance) + distance_meters = np.array(depth_fd['dataset']) + depth = hypersim_distance_to_depth( + distance_meters) # in meters (planar depth) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs): + dataset = HyperSim(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) 
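A quick way to sanity-check the distance-to-planar-depth conversion above is to synthesise a distance map whose planar depth is known and confirm the round trip. The sketch below is not part of the patch; the helper names `ray_norms` and `distance_to_depth` are illustrative, but the formula and the constants are the ones hard-coded in `hypersim.py`.

```python
import numpy as np

WIDTH, HEIGHT, FOCAL = 1024, 768, 886.81  # constants hard-coded in hypersim.py


def ray_norms(width=WIDTH, height=HEIGHT, focal=FOCAL):
    # ||(x, y, f)|| for every pixel of the image plane, on the same grid as hypersim.py.
    xs = np.linspace(-0.5 * width + 0.5, 0.5 * width - 0.5, width, dtype=np.float32)
    ys = np.linspace(-0.5 * height + 0.5, 0.5 * height - 0.5, height, dtype=np.float32)
    return np.sqrt(xs[None, :] ** 2 + ys[:, None] ** 2 + focal ** 2)


def distance_to_depth(distance):
    # Same formula as hypersim_distance_to_depth: depth = distance * focal / ||ray||.
    return distance * FOCAL / ray_norms()


# A fronto-parallel wall at z = 3 m: the Euclidean distance along each pixel's ray
# is 3 * ||ray|| / focal, so the conversion should recover 3.0 m everywhere.
distance = 3.0 * ray_norms() / FOCAL
depth = distance_to_depth(distance)
print(depth.min(), depth.max())  # both ~3.0, up to float32 rounding
```

At the principal point the ray reduces to (0, 0, f), so distance and depth coincide; toward the corners the ratio focal / ||ray|| shrinks the stored distance back down to its z-component.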
diff --git a/src/flux/annotator/zoe/zoedepth/data/ibims.py b/src/flux/annotator/zoe/zoedepth/data/ibims.py new file mode 100644 index 0000000000000000000000000000000000000000..b66abfabcf4cfc617d4a60ec818780c3548d9920 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/ibims.py @@ -0,0 +1,81 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms as T + + +class iBims(Dataset): + def __init__(self, config): + root_folder = config.ibims_root + with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f: + imglist = f.read().split() + + samples = [] + for basename in imglist: + img_path = os.path.join(root_folder, 'rgb', basename + ".png") + depth_path = os.path.join(root_folder, 'depth', basename + ".png") + valid_mask_path = os.path.join( + root_folder, 'mask_invalid', basename+".png") + transp_mask_path = os.path.join( + root_folder, 'mask_transp', basename+".png") + + samples.append( + (img_path, depth_path, valid_mask_path, transp_mask_path)) + + self.samples = samples + # self.normalize = T.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __getitem__(self, idx): + img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx] + + img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype=np.uint16).astype('float')*50.0/65535 + + mask_valid = np.asarray(Image.open(valid_mask_path)) + mask_transp = np.asarray(Image.open(transp_mask_path)) + + # depth = depth * mask_valid * mask_transp + depth = np.where(mask_valid * mask_transp, depth, -1) + + img = torch.from_numpy(img).permute(2, 0, 1) + img = self.normalize(img) + depth = torch.from_numpy(depth).unsqueeze(0) + return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims') + + def __len__(self): + return len(self.samples) + + +def get_ibims_loader(config, batch_size=1, **kwargs): + dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs) + return dataloader diff --git a/src/flux/annotator/zoe/zoedepth/data/preprocess.py b/src/flux/annotator/zoe/zoedepth/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e08cc309dc823ae6efd7cda8db9eb37130dc5499 --- /dev/null 
+++ b/src/flux/annotator/zoe/zoedepth/data/preprocess.py @@ -0,0 +1,154 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +from dataclasses import dataclass +from typing import Tuple, List + +# dataclass to store the crop parameters +@dataclass +class CropParams: + top: int + bottom: int + left: int + right: int + + + +def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams: + gray_image = np.mean(rgb_image, axis=channel_axis) + h, w = gray_image.shape + + + def num_value_pixels(arr): + return np.sum(np.abs(arr - value) < level_diff_threshold) + + def is_above_tolerance(arr, total_pixels): + return (num_value_pixels(arr) / total_pixels) > tolerance + + # Crop top border until number of value pixels become below tolerance + top = min_border + while is_above_tolerance(gray_image[top, :], w) and top < h-1: + top += 1 + if top > cut_off: + break + + # Crop bottom border until number of value pixels become below tolerance + bottom = h - min_border + while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0: + bottom -= 1 + if h - bottom > cut_off: + break + + # Crop left border until number of value pixels become below tolerance + left = min_border + while is_above_tolerance(gray_image[:, left], h) and left < w-1: + left += 1 + if left > cut_off: + break + + # Crop right border until number of value pixels become below tolerance + right = w - min_border + while is_above_tolerance(gray_image[:, right], h) and right > 0: + right -= 1 + if w - right > cut_off: + break + + + return CropParams(top, bottom, left, right) + + +def get_white_border(rgb_image, value=255, **kwargs) -> CropParams: + """Crops the white border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + Returns: + Crop parameters. + """ + if value == 255: + # assert range of values in rgb image is [0, 255] + assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]." + assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]." + elif value == 1: + # assert range of values in rgb image is [0, 1] + assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]." 
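+    # Both branches delegate to get_border_params(), which walks inward from each edge while
+    # the fraction of pixels within level_diff_threshold of `value` stays above `tolerance`
+    # (giving up after `cut_off` pixels).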
+ + return get_border_params(rgb_image, value=value, **kwargs) + +def get_black_border(rgb_image, **kwargs) -> CropParams: + """Crops the black border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + + Returns: + Crop parameters. + """ + + return get_border_params(rgb_image, value=0, **kwargs) + +def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray: + """Crops the image according to the crop parameters. + + Args: + image: RGB or depth image, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped image. + """ + return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right] + +def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]: + """Crops the images according to the crop parameters. + + Args: + images: RGB or depth images, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped images. + """ + return tuple(crop_image(image, crop_params) for image in images) + +def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]: + """Crops the white and black border of the RGB and depth images. + + Args: + rgb: RGB image, shape (H, W, 3). This image is used to determine the border. + other_images: The other images to crop according to the border of the RGB image. + Returns: + Cropped RGB and other images. + """ + # crop black border + crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params) + + # crop white border + crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(*cropped_images, crop_params=crop_params) + + return cropped_images + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py b/src/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2bdb9aefe68ca4439f41eff3bba722c49fb976 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py @@ -0,0 +1,106 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
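+
+# Evaluation loader for SUN RGB-D: RGB images live under rgb/rgb/ with ground truth under
+# gt/gt/; depth PNGs are stored in millimetres and divided by 1000, and values beyond 8 m
+# are marked invalid (-1).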
+ +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "sunrgbd"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class SunRGBD(Dataset): + def __init__(self, data_dir_root): + # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze() + # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs] + # self.all_test = [os.path.join(data_dir_root, t) for t in all_test] + import glob + self.image_files = glob.glob( + os.path.join(data_dir_root, 'rgb', 'rgb', '*')) + self.depth_files = [ + r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0 + depth[depth > 8] = -1 + depth = depth[..., None] + return self.transform(dict(image=image, depth=depth)) + + def __len__(self): + return len(self.image_files) + + +def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs): + dataset = SunRGBD(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/src/flux/annotator/zoe/zoedepth/data/transforms.py b/src/flux/annotator/zoe/zoedepth/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..374416dff24fb4fd55598f3946d6d6b091ddefc9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/transforms.py @@ -0,0 +1,481 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import math +import random + +import cv2 +import numpy as np + + +class RandomFliplr(object): + """Horizontal flip of the sample with given probability. + """ + + def __init__(self, probability=0.5): + """Init. + + Args: + probability (float, optional): Flip probability. Defaults to 0.5. + """ + self.__probability = probability + + def __call__(self, sample): + prob = random.random() + + if prob < self.__probability: + for k, v in sample.items(): + if len(v.shape) >= 2: + sample[k] = np.fliplr(v).copy() + + return sample + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class RandomCrop(object): + """Get a random crop of the sample with the given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_if_needed=False, + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): output width + height (int): output height + resize_if_needed (bool, optional): If True, sample might be upsampled to ensure + that a crop of size (width, height) is possbile. Defaults to False. + """ + self.__size = (height, width) + self.__resize_if_needed = resize_if_needed + self.__image_interpolation_method = image_interpolation_method + + def __call__(self, sample): + + shape = sample["disparity"].shape + + if self.__size[0] > shape[0] or self.__size[1] > shape[1]: + if self.__resize_if_needed: + shape = apply_min_size( + sample, self.__size, self.__image_interpolation_method + ) + else: + raise Exception( + "Output size {} bigger than input size {}.".format( + self.__size, shape + ) + ) + + offset = ( + np.random.randint(shape[0] - self.__size[0] + 1), + np.random.randint(shape[1] - self.__size[1] + 1), + ) + + for k, v in sample.items(): + if k == "code" or k == "basis": + continue + + if len(sample[k].shape) >= 2: + sample[k] = v[ + offset[0]: offset[0] + self.__size[0], + offset[1]: offset[1] + self.__size[1], + ] + + return sample + + +class Resize(object): + """Resize sample to given size (width, height). 
+ """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + letter_box=False, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + self.__letter_box = letter_box + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + 
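+ # (editorial note) Worked example: with width=384, height=384, keep_aspect_ratio=True,
+ # resize_method="lower_bound" and ensure_multiple_of=32, get_size(640, 480) returns
+ # (512, 384): the height is scaled to the requested 384 and the width follows the same
+ # scale factor (0.8 * 640 = 512), which is already a multiple of 32.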
def make_letter_box(self, sample): + top = bottom = (self.__height - sample.shape[0]) // 2 + left = right = (self.__width - sample.shape[1]) // 2 + sample = cv2.copyMakeBorder( + sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0) + return sample + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__letter_box: + sample["image"] = self.make_letter_box(sample["image"]) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["disparity"] = self.make_letter_box( + sample["disparity"]) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, + height), interpolation=cv2.INTER_NEAREST + ) + + if self.__letter_box: + sample["depth"] = self.make_letter_box(sample["depth"]) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["mask"] = self.make_letter_box(sample["mask"]) + + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class ResizeFixed(object): + def __init__(self, size): + self.__size = size + + def __call__(self, sample): + sample["image"] = cv2.resize( + sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], self.__size[::- + 1], interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + self.__size[::-1], + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class Rescale(object): + """Rescale target values to the interval [0, max_val]. + If input is constant, values are set to max_val / 2. + """ + + def __init__(self, max_val=1.0, use_mask=True): + """Init. + + Args: + max_val (float, optional): Max output value. Defaults to 1.0. + use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True. + """ + self.__max_val = max_val + self.__use_mask = use_mask + + def __call__(self, sample): + disp = sample["disparity"] + + if self.__use_mask: + mask = sample["mask"] + else: + mask = np.ones_like(disp, dtype=np.bool) + + if np.sum(mask) == 0: + return sample + + min_val = np.min(disp[mask]) + max_val = np.max(disp[mask]) + + if max_val > min_val: + sample["disparity"][mask] = ( + (disp[mask] - min_val) / (max_val - min_val) * self.__max_val + ) + else: + sample["disparity"][mask] = np.ones_like( + disp[mask]) * self.__max_val / 2.0 + + return sample + + +# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class DepthToDisparity(object): + """Convert depth to disparity. Removes depth from sample. 
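+    (Editorial note: disparity is set to 1/depth for pixels with depth >= eps; pixels with smaller depth are masked out.)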
+ """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "depth" in sample + + sample["mask"][sample["depth"] < self.__eps] = False + + sample["disparity"] = np.zeros_like(sample["depth"]) + sample["disparity"][sample["depth"] >= self.__eps] = ( + 1.0 / sample["depth"][sample["depth"] >= self.__eps] + ) + + del sample["depth"] + + return sample + + +class DisparityToDepth(object): + """Convert disparity to depth. Removes disparity from sample. + """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "disparity" in sample + + disp = np.abs(sample["disparity"]) + sample["mask"][disp < self.__eps] = False + + # print(sample["disparity"]) + # print(sample["mask"].sum()) + # exit() + + sample["depth"] = np.zeros_like(disp) + sample["depth"][disp >= self.__eps] = ( + 1.0 / disp[disp >= self.__eps] + ) + + del sample["disparity"] + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/zoe/zoedepth/data/vkitti.py b/src/flux/annotator/zoe/zoedepth/data/vkitti.py new file mode 100644 index 0000000000000000000000000000000000000000..72a2e5a8346f6e630ede0e28d6959725af8d7e72 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/vkitti.py @@ -0,0 +1,151 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import os + +from PIL import Image +import numpy as np +import cv2 + + +class ToTensor(object): + def __init__(self): + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True): + import glob + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "test_color", '*.png')) + self.depth_files = [r.replace("test_color", "test_depth") + for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) + print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + + if self.do_kb_crop and False: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. 
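+ # NOTE (editorial): the `and False` in the crop condition above disables the KB crop
+ # entirely, so frames are returned at full resolution here.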
+ depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/src/flux/annotator/zoe/zoedepth/data/vkitti2.py b/src/flux/annotator/zoe/zoedepth/data/vkitti2.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcfb0414b7f3f21859f30ae34bd71689516a3e7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/data/vkitti2.py @@ -0,0 +1,187 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI2(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True, split="test"): + import glob + + # image paths are of the form /rgb///frames//Camera<0,1>/rgb_{}.jpg + self.image_files = glob.glob(os.path.join( + data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True) + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + # If train test split is not created, then create one. + # Split is such that 8% of the frames from each scene are used for testing. 
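+ # (editorial note) Concretely: each scene's file list is shuffled and split
+ # 92% / 8% into train.txt / test.txt the first time the dataset is built.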
+ if not os.path.exists(os.path.join(data_dir_root, "train.txt")): + import random + scenes = set([os.path.basename(os.path.dirname( + os.path.dirname(os.path.dirname(f)))) for f in self.image_files]) + train_files = [] + test_files = [] + for scene in scenes: + scene_files = [f for f in self.image_files if os.path.basename( + os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene] + random.shuffle(scene_files) + train_files.extend(scene_files[:int(len(scene_files) * 0.92)]) + test_files.extend(scene_files[int(len(scene_files) * 0.92):]) + with open(os.path.join(data_dir_root, "train.txt"), "w") as f: + f.write("\n".join(train_files)) + with open(os.path.join(data_dir_root, "test.txt"), "w") as f: + f.write("\n".join(test_files)) + + if split == "train": + with open(os.path.join(data_dir_root, "train.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + elif split == "test": + with open(os.path.join(data_dir_root, "test.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + # depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m + depth = Image.fromarray(depth) + # print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + if self.do_kb_crop: + if idx == 0: + print("Using KB input crop") + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. + depth = np.asarray(depth, dtype=np.float32) / 1. 
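+ # (editorial note) Depths beyond 80 m are marked invalid (-1) just below,
+ # presumably matching the usual 80 m cap used for KITTI-style evaluation.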
+ depth[depth > 80] = -1 + + depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI2(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti2_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/src/flux/annotator/zoe/zoedepth/models/__init__.py b/src/flux/annotator/zoe/zoedepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/__init__.py b/src/flux/annotator/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 0000000000000000000000000000000000000000..ee660bc93d44c28efe8d8c674e715ea2ecb4c183 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,379 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. 
(Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + print("Params passed to Resize transform:") + print("\twidth: ", width) + print("\theight: ", height) + print("\tresize_target: ", resize_target) + print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + print("\tensure_multiple_of: ", ensure_multiple_of) + print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, 
**kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + 
self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + print("img_size", img_size) + midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo') + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a13c80028de3d297de4a3f09cee1b20759acc006 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore @@ -0,0 +1,110 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are 
written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +*.png +*.pfm +*.jpg +*.jpeg +*.pt \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..466bc94ba3128ea9cbe4bde82bd2fd1fc9daa8af --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile @@ -0,0 +1,29 @@ +# enables cuda support in docker +FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 + +# install python 3.6, pip and requirements for opencv-python +# (see https://github.com/NVIDIA/nvidia-docker/issues/864) +RUN apt-get update && apt-get -y install \ + python3 \ + python3-pip \ + libsm6 \ + libxext6 \ + libxrender-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# install python dependencies +RUN pip3 install --upgrade pip +RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm + +# copy inference code +WORKDIR /opt/MiDaS +COPY ./midas ./midas +COPY ./*.py ./ + +# download model weights so the docker image can be used offline +RUN cd weights && {curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; } +RUN python3 run.py --model_type dpt_hybrid; exit 0 + +# entrypoint (dont forget to mount input and output directories) +CMD python3 run.py --model_type dpt_hybrid diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9568ea71c755b6938ee5482ba9f09be722e75943 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md @@ -0,0 +1,259 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + + +and our [preprint](https://arxiv.org/abs/2103.13413): + +> Vision Transformers for Dense Prediction +> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun + + +MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with +multi-objective optimization. +The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). +The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. 
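+As an editorial aside (not part of the original README), the "via PyTorch Hub" route described further below typically looks like the following sketch; the repo string, entry-point name and transform attributes are the ones published on the MiDaS hub page and are assumptions here rather than something this repository guarantees:
+
+```python
+import torch
+
+# Load a MiDaS model and its matching input transform from PyTorch Hub.
+model_type = "DPT_Large"  # other entry points are listed in hubconf.py
+midas = torch.hub.load("intel-isl/MiDaS", model_type)
+midas.eval()
+
+midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+transform = midas_transforms.dpt_transform  # use small_transform for MiDaS_small
+```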
+ +![](figures/Improvement_vs_FPS.png) + +### Setup + +1) Pick one or more models and download the corresponding weights to the `weights` folder: + +MiDaS 3.1 +- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) +- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt) +- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt) +- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin) + +MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) + +MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) + +1) Set up dependencies: + + ```shell + conda env create -f environment.yaml + conda activate midas-py310 + ``` + +#### optional + +For the Next-ViT model, execute + +```shell +git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit +``` + +For the OpenVINO model, install + +```shell +pip install openvino +``` + +### Usage + +1) Place one or more input images in the folder `input`. + +2) Run the model with + + ```shell + python run.py --model_type --input_path input --output_path output + ``` + where `````` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type), + [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type), + [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type), + [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type), + [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type). + +3) The resulting depth maps are written to the `output` folder. + +#### optional + +1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This + size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single + inference height but a range of different heights. Feel free to explore different heights by appending the extra + command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may + decrease the model accuracy. +2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is + supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution, + disregarding the aspect ratio while preserving the height, use the command line argument `--square`. 
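+For example (an editorial illustration, not taken from the original README, assuming `--height` takes the desired inference height as its value), the height and square options can be combined with a BEiT-512 run as follows:
+
+```shell
+python run.py --model_type dpt_beit_large_512 --input_path input --output_path output --height 384 --square
+```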
+ +#### via Camera + + If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths + away and choose a model type as shown above: + + ```shell + python run.py --model_type --side + ``` + + The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown + side-by-side for comparison. + +#### via Docker + +1) Make sure you have installed Docker and the + [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)). + +2) Build the Docker image: + + ```shell + docker build -t midas . + ``` + +3) Run inference: + + ```shell + docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas + ``` + + This command passes through all of your NVIDIA GPUs to the container, mounts the + `input` and `output` directories and then runs the inference. + +#### via PyTorch Hub + +The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/) + +#### via TensorFlow or ONNX + +See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory. + +Currently only supports MiDaS v2.1. + + +#### via Mobile (iOS / Android) + +See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory. + +#### via ROS1 (Robot Operating System) + +See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory. + +Currently only supports MiDaS v2.1. DPT-based models to be added. + + +### Accuracy + +We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets +(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**. +$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to +MiDaS 3.0 DPTL-384. The models are grouped by the height used for inference, whereas the square training resolution is given by +the numbers in the model names. The table also shows the **number of parameters** (in millions) and the +**frames per second** for inference at the training resolution (for GPU RTX 3090): + +| MiDaS Model | DIW
WHDR | Eth3d<br>AbsRel | Sintel<br>AbsRel | TUM<br>δ1 | KITTI<br>δ1 | NYUv2<br>δ1 | $\color{green}{\textsf{Imp.}}$<br>% | Par.<br>M | FPS
  | +|-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| +| **Inference height 512** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | +| | | | | | | | | | | +| **Inference height 384** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | +| [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | +| [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | +| [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | +| [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | +| [v3.1 Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | +| [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | +| [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | +| [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | +| [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | +| | | | | | | | | | | +| **Inference height 256** | | | | | | | | | | +| [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | +| [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | +| | | | | | | | | | | +| **Inference height 224** | | | | | | | | | | +| [v3.1 
LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | + +* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ +$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model +does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other +validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the +improvement, because these quantities are averages over the pixels of an image and do not take into account the +advantage of more details due to a higher resolution.\ +Best values per column and same validation height in bold + +#### Improvement + +The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 +DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then +the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. + +Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the +improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 +and v2.0 Large384 respectively instead of v3.0 DPTL-384. + +### Depth map comparison + +Zoom in for better visibility +![](figures/Comparison.png) + +### Speed on Camera Feed + +Test configuration +- Windows 10 +- 11th Gen Intel Core i7-1185G7 3.00GHz +- 16GB RAM +- Camera resolution 640x480 +- openvino_midas_v21_small_256 + +Speed: 22 FPS + +### Changelog + +* [Dec 2022] Released MiDaS v3.1: + - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) + - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split + - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 + - Integrated live depth estimation from camera feed +* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). +* [Apr 2021] Released MiDaS v3.0: + - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 + - Additional models can be found [here](https://github.com/isl-org/DPT) +* [Nov 2020] Released MiDaS v2.1: + - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) + - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. 
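+As a concrete check of the improvement formula above: using the v3.1 BEiT_L-512 row at inference height 512 and the v3.0 DPT_L-384 reference row from the table (column order taken as DIW, ETH3D, Sintel, TUM, KITTI, NYUv2), the listed 19% can be reproduced with a few lines of plain Python:
+
+```python
+# Zero-shot errors copied from the table above (the KITTI/NYU values of BEiT_L-512 are not zero-shot).
+beit_l_512 = [0.1137, 0.0659, 0.2366, 6.13, 11.56, 1.86]  # v3.1 BEiT_L-512, inference height 512
+dpt_l_384 = [0.1082, 0.0888, 0.2697, 9.97, 8.46, 8.32]    # v3.0 DPT_L-384 reference
+
+improvement = 100 * (1 - sum(e / r for e, r in zip(beit_l_512, dpt_l_384)) / 6)
+print(f"{improvement:.0f}%")  # -> 19%
+```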
+ - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) + - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots +* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). +* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust +* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@ARTICLE {Ranftl2022, + author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", + title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", + journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", + year = "2022", + volume = "44", + number = "3" +} +``` + +If you use a DPT-based model, please also cite: + +``` +@article{Ranftl2021, + author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, + title = {Vision Transformers for Dense Prediction}, + journal = {ICCV}, + year = {2021}, +} +``` + +### Acknowledgements + +Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). +We'd like to thank the authors for making these libraries available. + +### License + +MIT License diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9abe5693b9e0de56b7d20728f4d0e6333c5822d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml @@ -0,0 +1,16 @@ +name: midas-py310 +channels: + - pytorch + - defaults +dependencies: + - nvidia::cudatoolkit=11.7 + - python=3.10.8 + - pytorch::pytorch=1.13.0 + - torchvision=0.14.0 + - pip=22.3.1 + - numpy=1.23.4 + - pip: + - opencv-python==4.6.0.66 + - imutils==0.5.4 + - timm==0.6.12 + - einops==0.6.0 \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d638be5151c4e305daff0c47d1ea3fc8066377d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py @@ -0,0 +1,435 @@ +dependencies = ["torch"] + +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small + +def DPT_BEiT_L_512(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_512 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_512", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS 
DPT_BEiT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitb16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2l24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2b24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_T_256(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_T_256 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2t16_256", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Swin_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Swin_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swinl12_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, 
check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Next_ViT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="next_vit_large_6m", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_LeViT_224(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_LeViT_224 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Large(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Large model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Hybrid(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Hybrid model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitb_rn50_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet() + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS_small(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) + + if pretrained: + checkpoint = ( + 
"https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + + +def transforms(): + import cv2 + from torchvision.transforms import Compose + from midas.transforms import Resize, NormalizeImage, PrepareForNet + from midas import transforms + + transforms.default_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.small_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.dpt_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.beit512_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 512, + 512, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin384_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin256_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.levit_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 224, + 224, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + return transforms diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder 
b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..e616dfd4026f448f9e22d35c6ad8b0028732acb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py @@ -0,0 +1,196 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. 
+ """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. + """ + if self.gamma_1 is None: + x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[0, 4, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return 
_make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py @@ -0,0 +1,106 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + +from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. 
+ """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. + """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py @@ -0,0 +1,39 @@ +import timm + +import torch.nn as nn + +from pathlib import Path +from .utils import activations, forward_default, get_activation + +from ..external.next_vit.classification.nextvit import * + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=[2, 6, 36, 39], +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks == None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", 
pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py new file mode 100644 index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py @@ -0,0 +1,52 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=[1, 1, 17, 1], + patch_grid=[96, 96] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py new file 
mode 100644 index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py @@ -0,0 +1,249 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + 
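+ # Note: for the plain ViT and BEiT backbones, each act_postprocess is
+ # [readout op, Transpose, Unflatten, 1x1 projection, optional resampling conv]; the slices
+ # [0:2] and [3:] are applied around the explicit, input-size-aware unflatten above, so the
+ # four maps returned here come out at strides 4, 8, 16 and 32 of the input resolution.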
return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def make_backbone_default( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + 
make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + patch_size=[16, 16], + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + value = nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py @@ -0,0 +1,439 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": 
+ pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, 
exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. 
+ + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3129d09cb43a7c79b23916236991fabbedb78f55 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. 
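+        # Note: each entry below lists the backbone block indices whose activations are
+        # tapped as the multi-scale features for the fusion decoder (four hooks for most
+        # backbones, three for LeViT). For the hierarchical backbones only the index
+        # ranges documented next to the corresponding entries are valid.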
+ hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + 
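+        # The depth head: halve the channel count, upsample by 2x, project down to
+        # head_features_2 channels and then to a single-channel depth map; the final
+        # ReLU (controlled by non_negative) clamps predictions to non-negative values.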
head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cd1f2d43054bfd3d650587c7b2ed35f1347c9e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. 
+ + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == 
"dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + if not "openvino" in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if not "openvino" in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if not "openvino" in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if not "openvino" in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. 
+ + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "disparity" in sample:
+            disparity = sample["disparity"].astype(np.float32)
+            sample["disparity"] = np.ascontiguousarray(disparity)
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        return sample
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..45c18f7f0bfe40c0db373e8a94716867705f5827
--- /dev/null
+++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md
@@ -0,0 +1,70 @@
+## Mobile version of MiDaS for iOS / Android - Monocular Depth Estimation
+
+### Accuracy
+
+* Old small model - ResNet50 default-decoder 384x384
+* New small model - EfficientNet-Lite3 small-decoder 256x256
+
+**Zero-shot error** (the lower - the better):
+
+| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 |
+|---|---|---|---|---|---|---|
+| Old small model 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 |
+| New small model 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** |
+| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** |
+
+None of the Train/Valid/Test subsets of the datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were involved in training or fine-tuning.
+
+### Inference speed (FPS) on iOS / Android
+
+**Frames Per Second** (the higher - the better):
+
+| Model | iPhone CPU | iPhone GPU | iPhone NPU | OnePlus8 CPU | OnePlus8 GPU | OnePlus8 NNAPI |
+|---|---|---|---|---|---|---|
+| Old small model 384x384 | 0.6 | N/A | N/A | 0.45 | 0.50 | 0.50 |
+| New small model 256x256 | 8 | 22 | **30** | 6 | **22** | 4 |
+| SpeedUp, X times | **12.8x** | - | - | **13.2x** | **44x** | **8x** |
+
+N/A - run-time error (no data available)
+
+
+#### Models:
+
+* Old small model - ResNet50 default-decoder 1x384x384x3, batch=1 FP32 (converters: Pytorch -> ONNX - [onnx_tf](https://github.com/onnx/onnx-tensorflow) -> (saved model) PB -> TFlite)
+
+    (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor)
+
+* New small model - EfficientNet-Lite3 small-decoder 1x256x256x3, batch=1 FP32 (custom converter: Pytorch -> TFlite)
+
+    (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor, HRWSI, IRS, TartanAir, BlendedMVS, ApolloScape)
+
+#### Frameworks for training and conversions:
+```
+pip install torch==1.6.0 torchvision==0.7.0
+pip install tf-nightly-gpu==2.5.0.dev20201031 tensorflow-addons==0.11.2 numpy==1.18.0
+git clone --depth 1 --branch v1.6.0 https://github.com/onnx/onnx-tensorflow
+```
+
+#### SoC - OS - Library:
+
+* iPhone 11 (A13 Bionic) - iOS 13.7 - TensorFlowLiteSwift 0.0.1-nightly
+* OnePlus 8 (Snapdragon 865) - Android 10 - org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly
+
+
+### Citation
+
+This repository contains code to compute depth from a single image.
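As a concrete illustration of how the converted small model above can be exercised outside the mobile apps, here is a minimal Python sketch of TFLite inference. The file name `model_opt.tflite`, the 1x256x256x3 float input, and the simple 0..1 input scaling are assumptions for illustration only; check them against the actual converted model before use.

```python
# Hypothetical sketch: run a converted MiDaS small TFLite model on a single image.
import numpy as np
import tensorflow as tf
from PIL import Image

interpreter = tf.lite.Interpreter(model_path="model_opt.tflite")  # assumed file name
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# Assumed preprocessing: resize to 256x256 and scale to [0, 1]; the real conversion
# pipeline may expect a different normalization.
img = Image.open("input.jpg").convert("RGB").resize((256, 256))
x = np.asarray(img, dtype=np.float32)[None] / 255.0
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
depth = interpreter.get_tensor(out["index"])  # relative inverse depth map
print(depth.shape)
```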
It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2fbe357549c64ae2966d5c3013a9179427b7b396 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore @@ -0,0 +1,13 @@ +*.iml +.gradle +/local.properties +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +.DS_Store +/build +/captures +.externalNativeBuild + +/.gradle/ +/.idea/ \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md new file mode 100644 index 0000000000000000000000000000000000000000..72014bdfa2cd701a6453debbc8e53fcc15c0a5dc --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md @@ -0,0 +1,414 @@ +# TensorFlow Lite Android image classification example + +This document walks through the code of a simple Android mobile application that +demonstrates +[image classification](https://www.tensorflow.org/lite/models/image_classification/overview) +using the device camera. + +## Explore the code + +We're now going to walk through the most important parts of the sample code. + +### Get camera input + +This mobile application gets the camera input using the functions defined in the +file +[`CameraActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java). +This file depends on +[`AndroidManifest.xml`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/AndroidManifest.xml) +to set the camera orientation. + +`CameraActivity` also contains code to capture user preferences from the UI and +make them available to other classes via convenience methods. 
+ +```java +model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); +device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); +numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); +``` + +### Classifier + +This Image Classification Android reference app demonstrates two implementation +solutions, +[`lib_task_api`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api) +that leverages the out-of-box API from the +[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier), +and +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support) +that creates the custom inference pipleline using the +[TensorFlow Lite Support Library](https://www.tensorflow.org/lite/inference_with_metadata/lite_support). + +Both solutions implement the file `Classifier.java` (see +[the one in lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java) +and +[the one in lib_support](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java)) +that contains most of the complex logic for processing the camera input and +running inference. + +Two subclasses of the `Classifier` exist, as in `ClassifierFloatMobileNet.java` +and `ClassifierQuantizedMobileNet.java`, which contain settings for both +floating point and +[quantized](https://www.tensorflow.org/lite/performance/post_training_quantization) +models. + +The `Classifier` class implements a static method, `create`, which is used to +instantiate the appropriate subclass based on the supplied model type (quantized +vs floating point). + +#### Using the TensorFlow Lite Task Library + +Inference can be done using just a few lines of code with the +[`ImageClassifier`](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier) +in the TensorFlow Lite Task Library. + +##### Load model and create ImageClassifier + +`ImageClassifier` expects a model populated with the +[model metadata](https://www.tensorflow.org/lite/convert/metadata) and the label +file. See the +[model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements) +for more details. + +`ImageClassifierOptions` allows manipulation on various inference options, such +as setting the maximum number of top scored results to return using +`setMaxResults(MAX_RESULTS)`, and setting the score threshold using +`setScoreThreshold(scoreThreshold)`. + +```java +// Create the ImageClassifier instance. +ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); +imageClassifier = ImageClassifier.createFromFileAndOptions(activity, + getModelPath(), options); +``` + +`ImageClassifier` currently does not support configuring delegates and +multithread, but those are on our roadmap. Please stay tuned! + +##### Run inference + +`ImageClassifier` contains builtin logic to preprocess the input image, such as +rotating and resizing an image. Processing options can be configured through +`ImageProcessingOptions`. 
In the following example, input images are rotated to +the up-right angle and cropped to the center as the model expects a square input +(`224x224`). See the +[Java doc of `ImageClassifier`](https://github.com/tensorflow/tflite-support/blob/195b574f0aa9856c618b3f1ad87bd185cddeb657/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java#L22) +for more details about how the underlying image processing is performed. + +```java +TensorImage inputImage = TensorImage.fromBitmap(bitmap); +int width = bitmap.getWidth(); +int height = bitmap.getHeight(); +int cropSize = min(width, height); +ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + +List results = imageClassifier.classify(inputImage, + imageOptions); +``` + +The output of `ImageClassifier` is a list of `Classifications` instance, where +each `Classifications` element is a single head classification result. All the +demo models are single head models, therefore, `results` only contains one +`Classifications` object. Use `Classifications.getCategories()` to get a list of +top-k categories as specified with `MAX_RESULTS`. Each `Category` object +contains the srting label and the score of that category. + +To match the implementation of +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support), +`results` is converted into `List` in the method, +`getRecognitions`. + +#### Using the TensorFlow Lite Support Library + +##### Load model and create interpreter + +To perform inference, we need to load a model file and instantiate an +`Interpreter`. This happens in the constructor of the `Classifier` class, along +with loading the list of class labels. Information about the device type and +number of threads is used to configure the `Interpreter` via the +`Interpreter.Options` instance passed into its constructor. Note that if a GPU, +DSP (Digital Signal Processor) or NPU (Neural Processing Unit) is available, a +[`Delegate`](https://www.tensorflow.org/lite/performance/delegates) can be used +to take full advantage of these hardware. + +Please note that there are performance edge cases and developers are adviced to +test with a representative set of devices prior to production. + +```java +protected Classifier(Activity activity, Device device, int numThreads) throws + IOException { + tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + labels = FileUtil.loadLabels(activity, getLabelPath()); +... +``` + +For Android devices, we recommend pre-loading and memory mapping the model file +to offer faster load times and reduce the dirty pages in memory. The method +`FileUtil.loadMappedFile` does this, returning a `MappedByteBuffer` containing +the model. + +The `MappedByteBuffer` is passed into the `Interpreter` constructor, along with +an `Interpreter.Options` object. 
This object can be used to configure the +interpreter, for example by setting the number of threads (`.setNumThreads(1)`) +or enabling [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks) +(`.addDelegate(nnApiDelegate)`). + +##### Pre-process bitmap image + +Next in the `Classifier` constructor, we take the input camera bitmap image, +convert it to a `TensorImage` format for efficient processing and pre-process +it. The steps are shown in the private 'loadImage' method: + +```java +/** Loads input image, and applys preprocessing. */ +private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + image.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight()); + int numRoration = sensorOrientation / 90; + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.BILINEAR)) + .add(new Rot90Op(numRoration)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); +} +``` + +The pre-processing is largely the same for quantized and float models with one +exception: Normalization. + +In `ClassifierFloatMobileNet`, the normalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 127.5f; +private static final float IMAGE_STD = 127.5f; +``` + +In `ClassifierQuantizedMobileNet`, normalization is not required. Thus the +nomalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 0.0f; +private static final float IMAGE_STD = 1.0f; +``` + +##### Allocate output object + +Initiate the output `TensorBuffer` for the output of the model. + +```java +/** Output probability TensorBuffer. */ +private final TensorBuffer outputProbabilityBuffer; + +//... +// Get the array size for the output buffer from the TensorFlow Lite model file +int probabilityTensorIndex = 0; +int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, 1001} +DataType probabilityDataType = + tflite.getOutputTensor(probabilityTensorIndex).dataType(); + +// Creates the output tensor and its processor. +outputProbabilityBuffer = + TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + +// Creates the post processor for the output probability. +probabilityProcessor = + new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); +``` + +For quantized models, we need to de-quantize the prediction with the NormalizeOp +(as they are all essentially linear transformation). For float model, +de-quantize is not required. But to uniform the API, de-quantize is added to +float model too. Mean and std are set to 0.0f and 1.0f, respectively. To be more +specific, + +In `ClassifierQuantizedMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 255.0f; +``` + +In `ClassifierFloatMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 1.0f; +``` + +##### Run inference + +Inference is performed using the following in `Classifier` class: + +```java +tflite.run(inputImageBuffer.getBuffer(), + outputProbabilityBuffer.getBuffer().rewind()); +``` + +##### Recognize image + +Rather than call `run` directly, the method `recognizeImage` is used. 
It accepts +a bitmap and sensor orientation, runs inference, and returns a sorted `List` of +`Recognition` instances, each corresponding to a label. The method will return a +number of results bounded by `MAX_RESULTS`, which is 3 by default. + +`Recognition` is a simple class that contains information about a specific +recognition result, including its `title` and `confidence`. Using the +post-processing normalization method specified, the confidence is converted to +between 0 and 1 of a given class being represented by the image. + +```java +/** Gets the label to probability map. */ +Map labeledProbability = + new TensorLabel(labels, + probabilityProcessor.process(outputProbabilityBuffer)) + .getMapWithFloatValue(); +``` + +A `PriorityQueue` is used for sorting. + +```java +/** Gets the top-k results. */ +private static List getTopKProbability( + Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of + // the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), + entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; +} +``` + +### Display results + +The classifier is invoked and inference results are displayed by the +`processImage()` function in +[`ClassifierActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java). + +`ClassifierActivity` is a subclass of `CameraActivity` that contains method +implementations that render the camera image, run classification, and display +the results. The method `processImage()` runs classification on a background +thread as fast as possible, rendering information on the UI thread to avoid +blocking inference and creating latency. + +```java +@Override +protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, + previewHeight); + final int imageSizeX = classifier.getImageSizeX(); + final int imageSizeY = classifier.getImageSizeY(); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + final List results = + classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + showResultsInBottomSheet(results); + showFrameInfo(previewWidth + "x" + previewHeight); + showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(imageSizeX + "x" + imageSizeY); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); +} +``` + +Another important role of `ClassifierActivity` is to determine user preferences +(by interrogating `CameraActivity`), and instantiate the appropriately +configured `Classifier` subclass. 
This happens when the video feed begins (via +`onPreviewSizeChosen()`) and when options are changed in the UI (via +`onInferenceConfigurationChanged()`). + +```java +private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU && model == Model.QUANTIZED) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, "GPU does not yet supported quantized models.", + Toast.LENGTH_LONG) + .show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, + device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException e) { + LOGGER.e(e, "Failed to create classifier."); + } +} +``` diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md new file mode 100644 index 0000000000000000000000000000000000000000..faf415eb27ccc1a62357718d1e0a9b8c746de4e8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md @@ -0,0 +1,21 @@ +# MiDaS on Android smartphone by using TensorFlow-lite (TFLite) + + +* Either use Android Studio for compilation. 
+ +* Or use ready to install apk-file: + * Or use URL: https://i.diawi.com/CVb8a9 + * Or use QR-code: + +Scan QR-code or open URL -> Press `Install application` -> Press `Download` and wait for download -> Open -> Install -> Open -> Press: Allow MiDaS to take photo and video from the camera While using the APP + +![CVb8a9](https://user-images.githubusercontent.com/4096485/97727213-38552500-1ae1-11eb-8b76-4ea11216f76d.png) + +---- + +To use another model, you should convert it to `model_opt.tflite` and place it to the directory: `models\src\main\assets` + + +---- + +Original repository: https://github.com/isl-org/MiDaS diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1ae74c6780c277d75fedfb7511ff51f69941b48b --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore @@ -0,0 +1,3 @@ +/build + +/build/ \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..94e9886a55c7d54f71b424bb246c849dd6bd795d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle @@ -0,0 +1,56 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 28 + defaultConfig { + applicationId "org.tensorflow.lite.examples.classification" + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } + flavorDimensions "tfliteInference" + productFlavors { + // The TFLite inference is built using the TFLite Support library. + support { + dimension "tfliteInference" + } + // The TFLite inference is built using the TFLite Task library. 
+ taskApi { + dimension "tfliteInference" + } + } + +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + supportImplementation project(":lib_support") + taskApiImplementation project(":lib_task_api") + implementation 'androidx.appcompat:appcompat:1.0.0' + implementation 'androidx.coordinatorlayout:coordinatorlayout:1.0.0' + implementation 'com.google.android.material:material:1.0.0' + + androidTestImplementation 'androidx.test.ext:junit:1.1.1' + androidTestImplementation 'com.google.truth:truth:1.0.1' + androidTestImplementation 'androidx.test:runner:1.2.0' + androidTestImplementation 'androidx.test:rules:1.1.0' +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..3653d8799092492ebbb16c7c956eb50e3d404aa4 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java new file mode 100644 index 0000000000000000000000000000000000000000..0194132890aae659c2a70d33106306ed665b22e8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.util.Log; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.rule.ActivityTestRule; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Scanner; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +/** Golden test for Image Classification Reference app. */ +@RunWith(AndroidJUnit4.class) +public class ClassifierTest { + + @Rule + public ActivityTestRule rule = + new ActivityTestRule<>(ClassifierActivity.class); + + private static final String[] INPUTS = {"fox.jpg"}; + private static final String[] GOLDEN_OUTPUTS_SUPPORT = {"fox-mobilenet_v1_1.0_224_support.txt"}; + private static final String[] GOLDEN_OUTPUTS_TASK = {"fox-mobilenet_v1_1.0_224_task_api.txt"}; + + @Test + public void classificationResultsShouldNotChange() throws IOException { + ClassifierActivity activity = rule.getActivity(); + Classifier classifier = Classifier.create(activity, Model.FLOAT_MOBILENET, Device.CPU, 1); + for (int i = 0; i < INPUTS.length; i++) { + String imageFileName = INPUTS[i]; + String goldenOutputFileName; + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // This is a temporary workaround to set different golden rest results as the preprocessing + // of lib_support and lib_task_api are different. Will merge them once the above TODO is + // resolved. 
+ if (Classifier.TAG.equals("ClassifierWithSupport")) { + goldenOutputFileName = GOLDEN_OUTPUTS_SUPPORT[i]; + } else { + goldenOutputFileName = GOLDEN_OUTPUTS_TASK[i]; + } + Bitmap input = loadImage(imageFileName); + List goldenOutput = loadRecognitions(goldenOutputFileName); + + List result = classifier.recognizeImage(input, 0); + Iterator goldenOutputIterator = goldenOutput.iterator(); + + for (Recognition actual : result) { + Assert.assertTrue(goldenOutputIterator.hasNext()); + Recognition expected = goldenOutputIterator.next(); + assertThat(actual.getTitle()).isEqualTo(expected.getTitle()); + assertThat(actual.getConfidence()).isWithin(0.01f).of(expected.getConfidence()); + } + } + } + + private static Bitmap loadImage(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load image from assets"); + } + return BitmapFactory.decodeStream(inputStream); + } + + private static List loadRecognitions(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load probability results from assets"); + } + Scanner scanner = new Scanner(inputStream); + List result = new ArrayList<>(); + while (scanner.hasNext()) { + String category = scanner.next(); + category = category.replace('_', ' '); + if (!scanner.hasNextFloat()) { + break; + } + float probability = scanner.nextFloat(); + Recognition recognition = new Recognition(null, category, probability, null); + result.add(recognition); + } + return result; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..7a414d5176a117262dce56c2220e6b71791287de --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..d1eb26c862c04bf573ecc4eb127e7460f0b100fc --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -0,0 +1,717 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.Manifest; +import android.app.Fragment; +import android.content.Context; +import android.content.pm.PackageManager; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.RectF; +import android.hardware.Camera; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Build; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Trace; +import androidx.annotation.NonNull; +import androidx.annotation.UiThread; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewTreeObserver; +import android.view.WindowManager; +import android.widget.AdapterView; +import android.widget.ImageView; +import android.widget.LinearLayout; +import android.widget.Spinner; +import android.widget.TextView; +import android.widget.Toast; +import com.google.android.material.bottomsheet.BottomSheetBehavior; +import java.nio.ByteBuffer; +import java.util.List; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public abstract class CameraActivity extends AppCompatActivity + implements OnImageAvailableListener, + Camera.PreviewCallback, + View.OnClickListener, + AdapterView.OnItemSelectedListener { + private static final Logger LOGGER = new Logger(); + + private static final int PERMISSIONS_REQUEST = 1; + + private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; + protected int previewWidth = 0; + protected int previewHeight = 0; + private Handler handler; + private HandlerThread handlerThread; + private boolean useCamera2API; + private boolean isProcessingFrame = false; + private byte[][] yuvBytes = new byte[3][]; + private int[] rgbBytes = null; + private int yRowStride; + private Runnable postInferenceCallback; + private Runnable imageConverter; + private LinearLayout bottomSheetLayout; + private LinearLayout gestureLayout; + private BottomSheetBehavior sheetBehavior; + protected TextView recognitionTextView, + recognition1TextView, + recognition2TextView, + recognitionValueTextView, + recognition1ValueTextView, + recognition2ValueTextView; + protected TextView frameValueTextView, + cropValueTextView, + cameraResolutionTextView, + rotationTextView, + inferenceTimeTextView; + protected ImageView bottomSheetArrowImageView; + private ImageView plusImageView, minusImageView; + private Spinner modelSpinner; + private Spinner deviceSpinner; + private TextView threadsTextView; + + //private Model model = Model.QUANTIZED_EFFICIENTNET; + //private Device device = Device.CPU; + private 
Model model = Model.FLOAT_EFFICIENTNET; + private Device device = Device.GPU; + private int numThreads = -1; + + @Override + protected void onCreate(final Bundle savedInstanceState) { + LOGGER.d("onCreate " + this); + super.onCreate(null); + getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); + + setContentView(R.layout.tfe_ic_activity_camera); + + if (hasPermission()) { + setFragment(); + } else { + requestPermission(); + } + + threadsTextView = findViewById(R.id.threads); + plusImageView = findViewById(R.id.plus); + minusImageView = findViewById(R.id.minus); + modelSpinner = findViewById(R.id.model_spinner); + deviceSpinner = findViewById(R.id.device_spinner); + bottomSheetLayout = findViewById(R.id.bottom_sheet_layout); + gestureLayout = findViewById(R.id.gesture_layout); + sheetBehavior = BottomSheetBehavior.from(bottomSheetLayout); + bottomSheetArrowImageView = findViewById(R.id.bottom_sheet_arrow); + + ViewTreeObserver vto = gestureLayout.getViewTreeObserver(); + vto.addOnGlobalLayoutListener( + new ViewTreeObserver.OnGlobalLayoutListener() { + @Override + public void onGlobalLayout() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.JELLY_BEAN) { + gestureLayout.getViewTreeObserver().removeGlobalOnLayoutListener(this); + } else { + gestureLayout.getViewTreeObserver().removeOnGlobalLayoutListener(this); + } + // int width = bottomSheetLayout.getMeasuredWidth(); + int height = gestureLayout.getMeasuredHeight(); + + sheetBehavior.setPeekHeight(height); + } + }); + sheetBehavior.setHideable(false); + + sheetBehavior.setBottomSheetCallback( + new BottomSheetBehavior.BottomSheetCallback() { + @Override + public void onStateChanged(@NonNull View bottomSheet, int newState) { + switch (newState) { + case BottomSheetBehavior.STATE_HIDDEN: + break; + case BottomSheetBehavior.STATE_EXPANDED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_down); + } + break; + case BottomSheetBehavior.STATE_COLLAPSED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + } + break; + case BottomSheetBehavior.STATE_DRAGGING: + break; + case BottomSheetBehavior.STATE_SETTLING: + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + break; + } + } + + @Override + public void onSlide(@NonNull View bottomSheet, float slideOffset) {} + }); + + recognitionTextView = findViewById(R.id.detected_item); + recognitionValueTextView = findViewById(R.id.detected_item_value); + recognition1TextView = findViewById(R.id.detected_item1); + recognition1ValueTextView = findViewById(R.id.detected_item1_value); + recognition2TextView = findViewById(R.id.detected_item2); + recognition2ValueTextView = findViewById(R.id.detected_item2_value); + + frameValueTextView = findViewById(R.id.frame_info); + cropValueTextView = findViewById(R.id.crop_info); + cameraResolutionTextView = findViewById(R.id.view_info); + rotationTextView = findViewById(R.id.rotation_info); + inferenceTimeTextView = findViewById(R.id.inference_info); + + modelSpinner.setOnItemSelectedListener(this); + deviceSpinner.setOnItemSelectedListener(this); + + plusImageView.setOnClickListener(this); + minusImageView.setOnClickListener(this); + + model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); + device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); + numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); + } + + protected int[] getRgbBytes() { + imageConverter.run(); + return rgbBytes; + } + + protected int 
getLuminanceStride() { + return yRowStride; + } + + protected byte[] getLuminance() { + return yuvBytes[0]; + } + + /** Callback for android.hardware.Camera API */ + @Override + public void onPreviewFrame(final byte[] bytes, final Camera camera) { + if (isProcessingFrame) { + LOGGER.w("Dropping frame!"); + return; + } + + try { + // Initialize the storage bitmaps once when the resolution is known. + if (rgbBytes == null) { + Camera.Size previewSize = camera.getParameters().getPreviewSize(); + previewHeight = previewSize.height; + previewWidth = previewSize.width; + rgbBytes = new int[previewWidth * previewHeight]; + onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); + } + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + return; + } + + isProcessingFrame = true; + yuvBytes[0] = bytes; + yRowStride = previewWidth; + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + camera.addCallbackBuffer(bytes); + isProcessingFrame = false; + } + }; + processImage(); + } + + /** Callback for Camera2 API */ + @Override + public void onImageAvailable(final ImageReader reader) { + // We need wait until we have some size from onPreviewSizeChosen + if (previewWidth == 0 || previewHeight == 0) { + return; + } + if (rgbBytes == null) { + rgbBytes = new int[previewWidth * previewHeight]; + } + try { + final Image image = reader.acquireLatestImage(); + + if (image == null) { + return; + } + + if (isProcessingFrame) { + image.close(); + return; + } + isProcessingFrame = true; + Trace.beginSection("imageAvailable"); + final Plane[] planes = image.getPlanes(); + fillBytes(planes, yuvBytes); + yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + image.close(); + isProcessingFrame = false; + } + }; + + processImage(); + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + Trace.endSection(); + return; + } + Trace.endSection(); + } + + @Override + public synchronized void onStart() { + LOGGER.d("onStart " + this); + super.onStart(); + } + + @Override + public synchronized void onResume() { + LOGGER.d("onResume " + this); + super.onResume(); + + handlerThread = new HandlerThread("inference"); + handlerThread.start(); + handler = new Handler(handlerThread.getLooper()); + } + + @Override + public synchronized void onPause() { + LOGGER.d("onPause " + this); + + handlerThread.quitSafely(); + try { + handlerThread.join(); + handlerThread = null; + handler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + + super.onPause(); + } + + @Override + public synchronized void onStop() { + LOGGER.d("onStop " + this); + super.onStop(); + } + + @Override + public synchronized void onDestroy() { + LOGGER.d("onDestroy " + this); + super.onDestroy(); + } + + protected synchronized void runInBackground(final Runnable r) { + if (handler != null) { + handler.post(r); + } + } + + @Override + public void onRequestPermissionsResult( + final int 
requestCode, final String[] permissions, final int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == PERMISSIONS_REQUEST) { + if (allPermissionsGranted(grantResults)) { + setFragment(); + } else { + requestPermission(); + } + } + } + + private static boolean allPermissionsGranted(final int[] grantResults) { + for (int result : grantResults) { + if (result != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + private boolean hasPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED; + } else { + return true; + } + } + + private void requestPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA)) { + Toast.makeText( + CameraActivity.this, + "Camera permission is required for this demo", + Toast.LENGTH_LONG) + .show(); + } + requestPermissions(new String[] {PERMISSION_CAMERA}, PERMISSIONS_REQUEST); + } + } + + // Returns true if the device supports the required hardware level, or better. + private boolean isHardwareLevelSupported( + CameraCharacteristics characteristics, int requiredLevel) { + int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); + if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { + return requiredLevel == deviceLevel; + } + // deviceLevel is not LEGACY, can use numerical sort + return requiredLevel <= deviceLevel; + } + + private String chooseCamera() { + final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); + try { + for (final String cameraId : manager.getCameraIdList()) { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + if (map == null) { + continue; + } + + // Fallback to camera1 API for internal cameras that don't have full support. + // This should help with legacy situations where using the camera2 API causes + // distorted or otherwise broken previews. 
+ useCamera2API = + (facing == CameraCharacteristics.LENS_FACING_EXTERNAL) + || isHardwareLevelSupported( + characteristics, CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); + LOGGER.i("Camera API lv2?: %s", useCamera2API); + return cameraId; + } + } catch (CameraAccessException e) { + LOGGER.e(e, "Not allowed to access camera"); + } + + return null; + } + + protected void setFragment() { + String cameraId = chooseCamera(); + + Fragment fragment; + if (useCamera2API) { + CameraConnectionFragment camera2Fragment = + CameraConnectionFragment.newInstance( + new CameraConnectionFragment.ConnectionCallback() { + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + previewHeight = size.getHeight(); + previewWidth = size.getWidth(); + CameraActivity.this.onPreviewSizeChosen(size, rotation); + } + }, + this, + getLayoutId(), + getDesiredPreviewFrameSize()); + + camera2Fragment.setCamera(cameraId); + fragment = camera2Fragment; + } else { + fragment = + new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); + } + + getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit(); + } + + protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { + // Because of the variable row stride it's not possible to know in + // advance the actual necessary dimensions of the yuv planes. + for (int i = 0; i < planes.length; ++i) { + final ByteBuffer buffer = planes[i].getBuffer(); + if (yuvBytes[i] == null) { + LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); + yuvBytes[i] = new byte[buffer.capacity()]; + } + buffer.get(yuvBytes[i]); + } + } + + protected void readyForNextImage() { + if (postInferenceCallback != null) { + postInferenceCallback.run(); + } + } + + protected int getScreenOrientation() { + switch (getWindowManager().getDefaultDisplay().getRotation()) { + case Surface.ROTATION_270: + return 270; + case Surface.ROTATION_180: + return 180; + case Surface.ROTATION_90: + return 90; + default: + return 0; + } + } + + @UiThread + protected void showResultsInTexture(float[] img_array, int imageSizeX, int imageSizeY) { + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, 
Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + + } + + protected void showResultsInBottomSheet(List results) { + if (results != null && results.size() >= 3) { + Recognition recognition = results.get(0); + if (recognition != null) { + if (recognition.getTitle() != null) recognitionTextView.setText(recognition.getTitle()); + if (recognition.getConfidence() != null) + recognitionValueTextView.setText( + String.format("%.2f", (100 * recognition.getConfidence())) + "%"); + } + + Recognition recognition1 = results.get(1); + if (recognition1 != null) { + if (recognition1.getTitle() != null) recognition1TextView.setText(recognition1.getTitle()); + if (recognition1.getConfidence() != null) + recognition1ValueTextView.setText( + String.format("%.2f", (100 * recognition1.getConfidence())) + "%"); + } + + Recognition recognition2 = results.get(2); + if (recognition2 != null) { + if (recognition2.getTitle() != null) recognition2TextView.setText(recognition2.getTitle()); + if (recognition2.getConfidence() != null) + recognition2ValueTextView.setText( + String.format("%.2f", (100 * recognition2.getConfidence())) + "%"); + } + } + } + + protected void showFrameInfo(String frameInfo) { + frameValueTextView.setText(frameInfo); + } + + protected void showCropInfo(String cropInfo) { + cropValueTextView.setText(cropInfo); + } + + protected void showCameraResolution(String cameraInfo) { + cameraResolutionTextView.setText(cameraInfo); + } + + protected void showRotationInfo(String rotation) { + rotationTextView.setText(rotation); + } + + protected void showInference(String inferenceTime) { + inferenceTimeTextView.setText(inferenceTime); + } + + protected Model getModel() { + return model; + } + + private void setModel(Model model) { + if (this.model != model) { + LOGGER.d("Updating model: " + model); + this.model = model; + onInferenceConfigurationChanged(); + } + } + + protected Device getDevice() { + return device; + } + + private void setDevice(Device device) { + if (this.device != device) { + LOGGER.d("Updating device: " + device); + this.device = device; + final boolean threadsEnabled = device == Device.CPU; + plusImageView.setEnabled(threadsEnabled); + minusImageView.setEnabled(threadsEnabled); + threadsTextView.setText(threadsEnabled ? 
String.valueOf(numThreads) : "N/A"); + onInferenceConfigurationChanged(); + } + } + + protected int getNumThreads() { + return numThreads; + } + + private void setNumThreads(int numThreads) { + if (this.numThreads != numThreads) { + LOGGER.d("Updating numThreads: " + numThreads); + this.numThreads = numThreads; + onInferenceConfigurationChanged(); + } + } + + protected abstract void processImage(); + + protected abstract void onPreviewSizeChosen(final Size size, final int rotation); + + protected abstract int getLayoutId(); + + protected abstract Size getDesiredPreviewFrameSize(); + + protected abstract void onInferenceConfigurationChanged(); + + @Override + public void onClick(View v) { + if (v.getId() == R.id.plus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads >= 9) return; + setNumThreads(++numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } else if (v.getId() == R.id.minus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads == 1) { + return; + } + setNumThreads(--numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } + } + + @Override + public void onItemSelected(AdapterView parent, View view, int pos, long id) { + if (parent == modelSpinner) { + setModel(Model.valueOf(parent.getItemAtPosition(pos).toString().toUpperCase())); + } else if (parent == deviceSpinner) { + setDevice(Device.valueOf(parent.getItemAtPosition(pos).toString())); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + // Do nothing. + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..13e5c0dc341a86b1ddd66c4b562e0bf767641b42 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java @@ -0,0 +1,575 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tensorflow.lite.examples.classification; + +import android.annotation.SuppressLint; +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.res.Configuration; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.text.TextUtils; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.Toast; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.Logger; + +/** + * Camera Connection Fragment that captures images from camera. + * + *

<p>Instantiated by newInstance.

+ */ +@SuppressWarnings("FragmentNotInstantiable") +public class CameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + + /** + * The camera preview size will be chosen to be the smallest frame by pixel size capable of + * containing a DESIRED_SIZE x DESIRED_SIZE square. + */ + private static final int MINIMUM_PREVIEW_SIZE = 320; + + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + private static final String FRAGMENT_DIALOG = "dialog"; + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private final Semaphore cameraOpenCloseLock = new Semaphore(1); + /** A {@link OnImageAvailableListener} to receive frames as they are available. */ + private final OnImageAvailableListener imageListener; + /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */ + private final Size inputSize; + /** The layout identifier to inflate for this Fragment. */ + private final int layout; + + private final ConnectionCallback cameraConnectionCallback; + private final CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + @Override + public void onCaptureProgressed( + final CameraCaptureSession session, + final CaptureRequest request, + final CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + final CameraCaptureSession session, + final CaptureRequest request, + final TotalCaptureResult result) {} + }; + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + /** The rotation in degrees of the camera sensor from the display. */ + private Integer sensorOrientation; + /** The {@link Size} of camera preview. */ + private Size previewSize; + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An {@link ImageReader} that handles preview frame capture. 
*/ + private ImageReader previewReader; + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + @Override + public void onOpened(final CameraDevice cd) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = cd; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(final CameraDevice cd) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + } + + @Override + public void onError(final CameraDevice cd, final int error) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + final Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + @SuppressLint("ValidFragment") + private CameraConnectionFragment( + final ConnectionCallback connectionCallback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + this.cameraConnectionCallback = connectionCallback; + this.imageListener = imageListener; + this.layout = layout; + this.inputSize = inputSize; + } + + /** + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose + * width and height are at least as large as the minimum of both, or an exact match if possible. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param width The minimum desired width + * @param height The minimum desired height + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) { + final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE); + final Size desiredSize = new Size(width, height); + + // Collect the supported resolutions that are at least as big as the preview Surface + boolean exactSizeFound = false; + final List bigEnough = new ArrayList(); + final List tooSmall = new ArrayList(); + for (final Size option : choices) { + if (option.equals(desiredSize)) { + // Set the size but don't return yet so that remaining sizes will still be logged. 
+ exactSizeFound = true; + } + + if (option.getHeight() >= minSize && option.getWidth() >= minSize) { + bigEnough.add(option); + } else { + tooSmall.add(option); + } + } + + LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize); + LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]"); + LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]"); + + if (exactSizeFound) { + LOGGER.i("Exact size match found."); + return desiredSize; + } + + // Pick the smallest of those, assuming we found any + if (bigEnough.size() > 0) { + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); + return chosenSize; + } else { + LOGGER.e("Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static CameraConnectionFragment newInstance( + final ConnectionCallback callback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + return new CameraConnectionFragment(callback, imageListener, layout, inputSize); + } + + /** + * Shows a {@link Toast} on the UI thread. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); + } + }); + } + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + public void setCamera(String cameraId) { + this.cameraId = cameraId; + } + + /** Sets up member variables related to camera. */ + private void setUpCameraOutputs() { + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + + // Danger, W.R.! Attempting to use too large a preview size could exceed the camera + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of + // garbage capture data. 
+ previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + inputSize.getWidth(), + inputSize.getHeight()); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + final int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.tfe_ic_camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + throw new IllegalStateException(getString(R.string.tfe_ic_camera_error)); + } + + cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); + } + + /** Opens the camera specified by {@link CameraConnectionFragment#cameraId}. */ + private void openCamera(final int width, final int height) { + setUpCameraOutputs(); + configureTransform(width, height); + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != previewReader) { + previewReader.close(); + previewReader = null; + } + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("ImageListener"); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + final SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + final Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. 
+ previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); + + // Create the reader for the preview frames. + previewReader = + ImageReader.newInstance( + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); + + previewReader.setOnImageAvailableListener(imageListener, backgroundHandler); + previewRequestBuilder.addTarget(previewReader.getSurface()); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface, previewReader.getSurface()), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(final CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + // Flash is automatically enabled when necessary. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + @Override + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** + * Configures the necessary {@link Matrix} transformation to `mTextureView`. This method should be + * called after the camera preview size is determined in setUpCameraOutputs and also the size of + * `mTextureView` is fixed. + * + * @param viewWidth The width of `mTextureView` + * @param viewHeight The height of `mTextureView` + */ + private void configureTransform(final int viewWidth, final int viewHeight) { + final Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + final Matrix matrix = new Matrix(); + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + final float centerX = viewRect.centerX(); + final float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + final float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** + * Callback for Activities to use to initialize their data once the selected preview size is + * known. 
+ */ + public interface ConnectionCallback { + void onPreviewSizeChosen(Size size, int cameraRotation); + } + + /** Compares two {@code Size}s based on their areas. */ + static class CompareSizesByArea implements Comparator { + @Override + public int compare(final Size lhs, final Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(final String message) { + final ErrorDialog dialog = new ErrorDialog(); + final Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(final Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(final DialogInterface dialogInterface, final int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..24b5d72fdb42d47e5d2c87e3f944b71105748c1b --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java @@ -0,0 +1,238 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tensorflow.lite.examples.classification; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Typeface; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.SystemClock; +import android.util.Size; +import android.util.TypedValue; +import android.view.TextureView; +import android.view.ViewStub; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.List; +import java.util.ArrayList; + +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.BorderedText; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; + +import android.widget.ImageView; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.Rect; +import android.graphics.RectF; +import android.graphics.PixelFormat; +import java.nio.ByteBuffer; + +public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener { + private static final Logger LOGGER = new Logger(); + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); + private static final float TEXT_SIZE_DIP = 10; + private Bitmap rgbFrameBitmap = null; + private long lastProcessingTimeMs; + private Integer sensorOrientation; + private Classifier classifier; + private BorderedText borderedText; + /** Input image size of the model along x axis. */ + private int imageSizeX; + /** Input image size of the model along y axis. 
*/ + private int imageSizeY; + + @Override + protected int getLayoutId() { + return R.layout.tfe_ic_camera_connection_fragment; + } + + @Override + protected Size getDesiredPreviewFrameSize() { + return DESIRED_PREVIEW_SIZE; + } + + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + borderedText = new BorderedText(textSizePx); + borderedText.setTypeface(Typeface.MONOSPACE); + + recreateClassifier(getModel(), getDevice(), getNumThreads()); + if (classifier == null) { + LOGGER.e("No classifier on preview!"); + return; + } + + previewWidth = size.getWidth(); + previewHeight = size.getHeight(); + + sensorOrientation = rotation - getScreenOrientation(); + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); + } + + @Override + protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); + final int cropSize = Math.min(previewWidth, previewHeight); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + //final List results = + // classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + final List results = new ArrayList<>(); + + float[] img_array = classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + + + /* + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + */ + + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + //showResultsInBottomSheet(results); + showResultsInTexture(img_array, imageSizeX, imageSizeY); + showFrameInfo(previewWidth + "x" + previewHeight); + 
showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(cropSize + "x" + cropSize); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); + } + + @Override + protected void onInferenceConfigurationChanged() { + if (rgbFrameBitmap == null) { + // Defer creation until we're getting camera frames. + return; + } + final Device device = getDevice(); + final Model model = getModel(); + final int numThreads = getNumThreads(); + runInBackground(() -> recreateClassifier(model, device, numThreads)); + } + + private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU + && (model == Model.QUANTIZED_MOBILENET || model == Model.QUANTIZED_EFFICIENTNET)) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, R.string.tfe_ic_gpu_quant_error, Toast.LENGTH_LONG).show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException | IllegalArgumentException e) { + LOGGER.e(e, "Failed to create classifier."); + runOnUiThread( + () -> { + Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show(); + }); + return; + } + + // Updates the input image size. + imageSizeX = classifier.getImageSizeX(); + imageSizeY = classifier.getImageSizeY(); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..760fe90375450c7b1356603c83fb37a68548ca13 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java @@ -0,0 +1,203 @@ +package org.tensorflow.lite.examples.classification; + +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import android.annotation.SuppressLint; +import android.app.Fragment; +import android.graphics.SurfaceTexture; +import android.hardware.Camera; +import android.hardware.Camera.CameraInfo; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import java.io.IOException; +import java.util.List; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; + +public class LegacyCameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + private Camera camera; + private Camera.PreviewCallback imageListener; + private Size desiredSize; + /** The layout identifier to inflate for this Fragment. */ + private int layout; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + + int index = getCameraId(); + camera = Camera.open(index); + + try { + Camera.Parameters parameters = camera.getParameters(); + List focusModes = parameters.getSupportedFocusModes(); + if (focusModes != null + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); + } + List cameraSizes = parameters.getSupportedPreviewSizes(); + Size[] sizes = new Size[cameraSizes.size()]; + int i = 0; + for (Camera.Size size : cameraSizes) { + sizes[i++] = new Size(size.width, size.height); + } + Size previewSize = + CameraConnectionFragment.chooseOptimalSize( + sizes, desiredSize.getWidth(), desiredSize.getHeight()); + parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight()); + camera.setDisplayOrientation(90); + camera.setParameters(parameters); + camera.setPreviewTexture(texture); + } catch (IOException exception) { + camera.release(); + } + + camera.setPreviewCallbackWithBuffer(imageListener); + Camera.Size s = camera.getParameters().getPreviewSize(); + camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]); + + textureView.setAspectRatio(s.height, s.width); + + camera.startPreview(); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) {} + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An additional thread for running tasks that shouldn't block the UI. 
*/ + private HandlerThread backgroundThread; + + @SuppressLint("ValidFragment") + public LegacyCameraConnectionFragment( + final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) { + this.imageListener = imageListener; + this.layout = layout; + this.desiredSize = desiredSize; + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + + if (textureView.isAvailable()) { + if (camera != null) { + camera.startPreview(); + } + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + stopCamera(); + stopBackgroundThread(); + super.onPause(); + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("CameraBackground"); + backgroundThread.start(); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + protected void stopCamera() { + if (camera != null) { + camera.stopPreview(); + camera.setPreviewCallback(null); + camera.release(); + camera = null; + } + } + + private int getCameraId() { + CameraInfo ci = new CameraInfo(); + for (int i = 0; i < Camera.getNumberOfCameras(); i++) { + Camera.getCameraInfo(i, ci); + if (ci.facing == CameraInfo.CAMERA_FACING_BACK) return i; + } + return -1; // No camera found + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..62e99ae70c2a7c4c60a776e7490742c5339e85f3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
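One edge case in the fragment above: getCameraId() returns -1 when no back-facing camera exists, and Camera.open(-1) in onSurfaceTextureAvailable would then fail at runtime. A defensive variant, purely illustrative and not part of the upstream example, could fall back to the first reported camera:

    // Illustrative drop-in for getCameraId() inside LegacyCameraConnectionFragment, hence the Android types.
    private int getCameraIdOrFallback() {
      final android.hardware.Camera.CameraInfo ci = new android.hardware.Camera.CameraInfo();
      final int cameraCount = android.hardware.Camera.getNumberOfCameras();
      for (int i = 0; i < cameraCount; i++) {
        android.hardware.Camera.getCameraInfo(i, ci);
        if (ci.facing == android.hardware.Camera.CameraInfo.CAMERA_FACING_BACK) {
          return i;
        }
      }
      return cameraCount > 0 ? 0 : -1; // fall back to the first camera; -1 only if none exist
    }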
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + private int ratioWidth = 0; + private int ratioHeight = 0; + + public AutoFitTextureView(final Context context) { + this(context, null); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. + * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(final int width, final int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + ratioWidth = width; + ratioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + final int width = MeasureSpec.getSize(widthMeasureSpec); + final int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == ratioWidth || 0 == ratioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * ratioWidth / ratioHeight) { + setMeasuredDimension(width, width * ratioHeight / ratioWidth); + } else { + setMeasuredDimension(height * ratioWidth / ratioHeight, height); + } + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java new file mode 100644 index 0000000000000000000000000000000000000000..dc302ac04f145c9a1673a2d7e630a94a05ab1b1a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
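The onMeasure logic above letterboxes the view to the requested ratio by shrinking whichever dimension is too large. The same arithmetic, extracted into a plain-Java sanity check with no Android dependency (AspectRatioMath is an illustrative name):

    // Same arithmetic as AutoFitTextureView.onMeasure.
    final class AspectRatioMath {
      static int[] measure(int width, int height, int ratioWidth, int ratioHeight) {
        if (ratioWidth == 0 || ratioHeight == 0) {
          return new int[] {width, height};
        }
        if (width < height * ratioWidth / ratioHeight) {
          return new int[] {width, width * ratioHeight / ratioWidth};
        }
        return new int[] {height * ratioWidth / ratioHeight, height};
      }

      public static void main(String[] args) {
        // A 640x480 preview is registered as setAspectRatio(480, 640);
        // on a 1080x1920 view it measures to 1080x1440, so the full width is used.
        int[] r = measure(1080, 1920, 480, 640);
        System.out.println(r[0] + "x" + r[1]);
      }
    }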
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.util.AttributeSet; +import android.view.View; +import java.util.LinkedList; +import java.util.List; + +/** A simple View providing a render callback to other classes. */ +public class OverlayView extends View { + private final List callbacks = new LinkedList(); + + public OverlayView(final Context context, final AttributeSet attrs) { + super(context, attrs); + } + + public void addCallback(final DrawCallback callback) { + callbacks.add(callback); + } + + @Override + public synchronized void draw(final Canvas canvas) { + for (final DrawCallback callback : callbacks) { + callback.drawCallback(canvas); + } + } + + /** Interface defining the callback for client classes. */ + public interface DrawCallback { + public void drawCallback(final Canvas canvas); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java new file mode 100644 index 0000000000000000000000000000000000000000..2c57f603f12200079c888793cfa40d9b10dabde3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.graphics.Paint; +import android.util.AttributeSet; +import android.util.TypedValue; +import android.view.View; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public class RecognitionScoreView extends View implements ResultsView { + private static final float TEXT_SIZE_DIP = 16; + private final float textSizePx; + private final Paint fgPaint; + private final Paint bgPaint; + private List results; + + public RecognitionScoreView(final Context context, final AttributeSet set) { + super(context, set); + + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + fgPaint = new Paint(); + fgPaint.setTextSize(textSizePx); + + bgPaint = new Paint(); + bgPaint.setColor(0xcc4285f4); + } + + @Override + public void setResults(final List results) { + this.results = results; + postInvalidate(); + } + + @Override + public void onDraw(final Canvas canvas) { + final int x = 10; + int y = (int) (fgPaint.getTextSize() * 1.5f); + + canvas.drawPaint(bgPaint); + + if (results != null) { + for (final Recognition recog : results) { + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); + y += (int) (fgPaint.getTextSize() * 1.5f); + } + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java new file mode 100644 index 0000000000000000000000000000000000000000..d055eb5f161a57fc439716efe6d49b7e45ef3fc7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
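RecognitionScoreView above sizes its text with TypedValue.applyDimension, which for COMPLEX_UNIT_DIP reduces to dip times the display density. A tiny off-device sketch of that conversion; DipToPx is an illustrative name and 2.75 is only an example density for a roughly 440 dpi screen:

    // Equivalent of TypedValue.applyDimension(COMPLEX_UNIT_DIP, dip, metrics): px = dip * metrics.density.
    final class DipToPx {
      static float dipToPx(float dip, float density) {
        return dip * density;
      }

      public static void main(String[] args) {
        // 16dp at density 2.75 -> 44px, the text size RecognitionScoreView draws with.
        System.out.println(dipToPx(16f, 2.75f));
      }
    }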
+==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public interface ResultsView { + public void setResults(final List results); +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 0000000000000000000000000000000000000000..b1517edf496ef5800b97d046b92012a9f94a34d0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml new file mode 100644 index 0000000000000000000000000000000000000000..70f4b24e35039e6bfc35989bcbe570a4bdc2ae07 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml @@ -0,0 +1,9 @@ + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml new file mode 100644 index 0000000000000000000000000000000000000000..757f4503314fb9e5837f68ac515f4487d9b5fc2c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml @@ -0,0 +1,9 @@ + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml new file mode 100644 index 0000000000000000000000000000000000000000..a64b853e79137f0fd95f9d5fa6e0552cc255c7ae --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml @@ -0,0 +1,9 @@ + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000000000000000000000000000000000000..d5fccc538c179838bfdce779c26eebb4fa0b5ce9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml new file mode 100644 index 
0000000000000000000000000000000000000000..b8f5d3559c4e83072d5d73a3241d240aa68daccf --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml @@ -0,0 +1,13 @@ + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..f0e1dae7afa15cf4a832de708f345482a6dfeff6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml new file mode 100644 index 0000000000000000000000000000000000000000..97e5e7c6df25da48977f9064a888fd3735e4986f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml @@ -0,0 +1,32 @@ + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml new file mode 100644 index 0000000000000000000000000000000000000000..77a348af90e2ed995ff106cd209cbf304c6b9153 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml @@ -0,0 +1,321 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml 
b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..ed82bafb536474c6a88c996b439a2781f31f3d3e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml @@ -0,0 +1,8 @@ + + + #ffa800 + #ff6f00 + #425066 + + #66000000 + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..5d3609029ca66b612c88b4f395e4e2e3cfc1f0e6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml @@ -0,0 +1,5 @@ + + + 15dp + 8dp + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..7d763d85efc49879c8d3c0641484f5f472bfaca0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml @@ -0,0 +1,21 @@ + + Midas + This device doesn\'t support Camera2 API. + GPU does not yet supported quantized models. + Model: + + Float_EfficientNet + + + + Device: + + GPU + CPU + NNAPI + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..ad09a13ec6b2de8920a7441c9992f3cc0eedcfda --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..14492756847191ca3beff4c2e012d378c4e44be6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle @@ -0,0 +1,27 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. 
+ +buildscript { + + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:4.0.0' + classpath 'de.undercouch:gradle-download-task:4.0.2' + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..9592636c07d9d5e6f61c0cfce1311d3e1ffcf34d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties @@ -0,0 +1,15 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +android.useAndroidX=true +android.enableJetifier=true diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..f3d88b1c2faf2fc91d853cd5d4242b5547257070 Binary files /dev/null and b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..1b16c34a71cf212ed0cfb883d14d1b8511903eb2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew new file mode 100644 index 0000000000000000000000000000000000000000..2fe81a7d95e4f9ad2c9b2a046707d36ceb3980b3 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew @@ -0,0 +1,183 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..9618d8d9607cd91a0efb866bcac4810064ba6fac --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat @@ -0,0 +1,100 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..5d463975293264765a941795601cddb6cfc84f00 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite + implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true } + // Use local TensorFlow library + // implementation 'org.tensorflow:tensorflow-lite-local:0.0.0' +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..24ec573e7d184e7d64118a723d6645fd92d6e6d9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,376 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import android.view.TextureView; +import android.view.ViewStub; + +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.nnapi.NnApiDelegate; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.TensorProcessor; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.image.ops.ResizeOp; +import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp; +import org.tensorflow.lite.support.image.ops.Rot90Op; +import org.tensorflow.lite.support.label.TensorLabel; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** A classifier specialized to label images using TensorFlow Lite. 
*/ +public abstract class Classifier { + public static final String TAG = "ClassifierWithSupport"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + QUANTIZED_EFFICIENTNET, + FLOAT_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. */ + private static final int MAX_RESULTS = 3; + + /** The loaded TensorFlow Lite model. */ + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + + /** Optional GPU delegate for accleration. */ + private GpuDelegate gpuDelegate = null; + + /** Optional NNAPI delegate for accleration. */ + private NnApiDelegate nnApiDelegate = null; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected Interpreter tflite; + + /** Options for configuring the Interpreter. */ + private final Interpreter.Options tfliteOptions = new Interpreter.Options(); + + /** Labels corresponding to the output of the vision model. */ + private final List labels; + + /** Input image TensorBuffer. */ + private TensorImage inputImageBuffer; + + /** Output probability TensorBuffer. */ + private final TensorBuffer outputProbabilityBuffer; + + /** Processer to apply post processing of the output probability. */ + private final TensorProcessor probabilityProcessor; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. 
*/ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + + // Loads labels out from the label file. + labels = FileUtil.loadLabels(activity, getLabelPath()); + + // Reads type and shape of input and output tensors, respectively. + int imageTensorIndex = 0; + int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3} + if(imageShape[1] != imageShape[2]) { + imageSizeY = imageShape[2]; + imageSizeX = imageShape[3]; + } else { + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType(); + int probabilityTensorIndex = 0; + int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES} + DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType(); + + // Creates the input tensor. + inputImageBuffer = new TensorImage(imageDataType); + + // Creates the output tensor and its processor. + outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + + // Creates the post processor for the output probability. + probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); + + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Runs inference and returns the classification results. */ + //public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + public float[] recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("loadImage"); + long startTimeForLoadImage = SystemClock.uptimeMillis(); + inputImageBuffer = loadImage(bitmap, sensorOrientation); + long endTimeForLoadImage = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage)); + + // Runs the inference call. 
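      // Note on the inference block below: tflite.run(...) reads the preprocessed input from
      // inputImageBuffer and writes the raw model output into outputProbabilityBuffer. The
      // rewind() call resets the output ByteBuffer's position to 0 so the interpreter fills it
      // from the start on every invocation. In this MiDaS-adapted example the raw float array is
      // then returned as-is; the usual label/probability post-processing is left commented out.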
+ Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind()); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + float[] img_array = outputProbabilityBuffer.getFloatArray(); + + // Gets the map of label and probability. + //Map labeledProbability = + // new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer)) + // .getMapWithFloatValue(); + Trace.endSection(); + + // Gets top-k results. + return img_array;//getTopKProbability(labeledProbability); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (tflite != null) { + tflite.close(); + tflite = null; + } + if (gpuDelegate != null) { + gpuDelegate.close(); + gpuDelegate = null; + } + if (nnApiDelegate != null) { + nnApiDelegate.close(); + nnApiDelegate = null; + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** Loads input image, and applies preprocessing. */ + private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + inputImageBuffer.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = min(bitmap.getWidth(), bitmap.getHeight()); + int numRotation = sensorOrientation / 90; + // TODO(b/143564309): Fuse ops inside ImageProcessor. + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // To get the same inference results as lib_task_api, which is built on top of the Task + // Library, use ResizeMethod.BILINEAR. + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.NEAREST_NEIGHBOR)) + //.add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR)) + .add(new Rot90Op(numRotation)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); + } + + /** Gets the top-k results. */ + private static List getTopKProbability(Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); + + /** Gets the name of the label file stored in Assets. */ + protected abstract String getLabelPath(); + + /** Gets the TensorOperator to nomalize the input image in preprocessing. */ + protected abstract TensorOperator getPreprocessNormalizeOp(); + + /** + * Gets the TensorOperator to dequantize the output probability in post processing. + * + *
<p>
For quantized model, we need de-quantize the prediction with NormalizeOp (as they are all + * essentially linear transformation). For float model, de-quantize is not required. But to + * uniform the API, de-quantize is added to float model too. Mean and std are set to 0.0f and + * 1.0f, respectively. + */ + protected abstract TensorOperator getPostprocessNormalizeOp(); +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..14dd027b26baefaedd979a8ac37f0bf984210ed4 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + private static final float IMAGE_MEAN = 115.0f; //127.0f; + private static final float IMAGE_STD = 58.0f; //128.0f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, repectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ //return "efficientnet-lite0-fp32.tflite"; + return "model_opt.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..40519de07cf5e887773250a4609a832b6060d684 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + + /** Float MobileNet requires additional normalization of the used input. */ + private static final float IMAGE_MEAN = 127.5f; + + private static final float IMAGE_STD = 127.5f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, repectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ return "model_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..d0d62f58d18333b6360ec30a4c85c9f1d38955ce --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ return "model_quant.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..94b06e3df659005c287733a8a37672863fdadd71 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ return "model_quant_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b5983986e3d56a77a41676b9195b0d0882b5fb96 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite Task Library + implementation('org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly') { changing = true } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..45da52a0d0dfa203255e0f2d44901ee0618e739f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.Rect; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.support.metadata.MetadataExtractor; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions.Orientation; +import org.tensorflow.lite.task.vision.classifier.Classifications; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier.ImageClassifierOptions; + +/** A classifier specialized to label images using TensorFlow Lite. */ +public abstract class Classifier { + public static final String TAG = "ClassifierWithTaskApi"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + FLOAT_EFFICIENTNET, + QUANTIZED_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. 
*/ + private static final int MAX_RESULTS = 3; + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected final ImageClassifier imageClassifier; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + if (device != Device.CPU || numThreads != 1) { + throw new IllegalArgumentException( + "Manipulating the hardware accelerators and numbers of threads is not allowed in the Task" + + " library currently. Only CPU + single thread is allowed."); + } + + // Create the ImageClassifier instance. 
+ ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); + imageClassifier = ImageClassifier.createFromFileAndOptions(activity, getModelPath(), options); + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + + // Get the input image size information of the underlying tflite model. + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + MetadataExtractor metadataExtractor = new MetadataExtractor(tfliteModel); + // Image shape is in the format of {1, height, width, 3}. + int[] imageShape = metadataExtractor.getInputTensorShape(/*inputIndex=*/ 0); + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + + /** Runs inference and returns the classification results. */ + public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + TensorImage inputImage = TensorImage.fromBitmap(bitmap); + int width = bitmap.getWidth(); + int height = bitmap.getHeight(); + int cropSize = min(width, height); + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // Task Library resizes the images using bilinear interpolation, which is slightly different from + // the nearest neighbor sampling algorithm used in lib_support. See + // https://github.com/tensorflow/examples/blob/0ef3d93e2af95d325c70ef3bcbbd6844d0631e07/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java#L310. + ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + + // Runs the inference call. + Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + List<Classifications> results = imageClassifier.classify(inputImage, imageOptions); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + Trace.endSection(); + + return getRecognitions(results); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (imageClassifier != null) { + imageClassifier.close(); + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** + * Converts a list of {@link Classifications} objects into a list of {@link Recognition} objects + * to match the interface of other inference methods, such as using the TFLite + * Support Library. + */ + private static List<Recognition> getRecognitions(List<Classifications> classifications) { + + final ArrayList<Recognition> recognitions = new ArrayList<>(); + // All the demo models are single head models. Get the first Classifications in the results.
+ for (Category category : classifications.get(0).getCategories()) { + recognitions.add( + new Recognition( + "" + category.getLabel(), category.getLabel(), category.getScore(), null)); + } + return recognitions; + } + + /* Convert the camera orientation in degree into {@link ImageProcessingOptions#Orientation}.*/ + private static Orientation getOrientation(int cameraOrientation) { + switch (cameraOrientation / 90) { + case 3: + return Orientation.BOTTOM_LEFT; + case 2: + return Orientation.BOTTOM_RIGHT; + case 1: + return Orientation.TOP_RIGHT; + default: + return Orientation.TOP_LEFT; + } + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..250794cc12d0e603aa47502322dc646d50689848 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. 
+ //return "efficientnet-lite0-fp32.tflite"; + return "model.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..0707de98de41395eaf3ddcfd74d6e36229a63760 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..05ca4fa6c409d0274a396c9b26c3c39ca8a8194e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "efficientnet-lite0-int8.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..978b08eeaf52a23eede437d61045db08d1dff163 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. 
+ * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224_quant.tflite"; + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8d825707af20cbbead6c4599f075599148e3511c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle @@ -0,0 +1,40 @@ +apply plugin: 'com.android.library' +apply plugin: 'de.undercouch.download' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +apply from:'download.gradle' diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle new file mode 100644 index 0000000000000000000000000000000000000000..ce76974a2c3bc6f8214461028e0dfa6ebc25d588 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle @@ -0,0 +1,10 @@ +def modelFloatDownloadUrl = "https://github.com/isl-org/MiDaS/releases/download/v2_1/model_opt.tflite" +def modelFloatFile = "model_opt.tflite" + +task downloadModelFloat(type: Download) { + src "${modelFloatDownloadUrl}" + dest project.ext.ASSET_DIR + "/${modelFloatFile}" + overwrite false +} + +preBuild.dependsOn downloadModelFloat diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. 
+# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..42951a56497c5f947efe4aea6a07462019fb152c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e86d89d2483f92b7e778589011fad60fbba3a318 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle @@ -0,0 +1,2 @@ +rootProject.name = 'TFLite Image Classification Demo App' +include ':app', ':lib_support', ':lib_task_api', ':models' \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f1150e3379e4a38d31ca7bb46dc4f31d79f482c2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore @@ -0,0 +1,2 @@ +# ignore model file +#*.tflite diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..4917371aa33a65fdfc66c02d914f05489c446430 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj @@ -0,0 +1,538 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = A1CE41C09920CCEC31985547 /* Pods_Midas.framework */; }; + 8402440123D9834600704ABD /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 8402440023D9834600704ABD /* README.md */; }; + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 840ECB1F238BAA2300C7D88A /* InfoCell.swift */; }; + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */; }; + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDD002341DE380017ED42 /* Main.storyboard */; }; + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 842DDB6D2372A82000F6BB94 /* OverlayView.swift */; }; + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */; }; + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846BAF7523E7FE13006FC136 /* Constants.swift */; }; + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FEC82341D36E00377D34 /* PreviewView.swift */; }; + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FECA2341D39800377D34 /* CameraFeedManager.swift */; }; + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */; }; + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB82361874A0052C104 /* TFLiteExtension.swift */; }; + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CEE2326338300A11A08 /* AppDelegate.swift */; }; + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CF02326338300A11A08 /* ViewController.swift */; }; + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 84B67CF52326338400A11A08 /* Assets.xcassets */; }; + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */; }; + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 84F232D4254C831E0011862E /* model_opt.tflite */; }; + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */ = {isa = PBXBuildFile; fileRef = 84FCF5912387BD7900663812 /* tfl_logo.png */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 8402440023D9834600704ABD /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; 
sourceTree = ""; }; + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InfoCell.swift; sourceTree = ""; }; + 840EDCFC2341DDD30017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = "Base.lproj/Launch Screen.storyboard"; sourceTree = ""; }; + 840EDD012341DE380017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayView.swift; sourceTree = ""; }; + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDataHandler.swift; sourceTree = ""; }; + 846BAF7523E7FE13006FC136 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; + 8474FEC82341D36E00377D34 /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = ""; }; + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CameraFeedManager.swift; sourceTree = ""; }; + 84884291236FF0A30043FC4C /* download_models.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = download_models.sh; sourceTree = ""; }; + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CVPixelBufferExtension.swift; sourceTree = ""; }; + 84952CB82361874A0052C104 /* TFLiteExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TFLiteExtension.swift; sourceTree = ""; }; + 84B67CEB2326338300A11A08 /* Midas.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Midas.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 84B67CEE2326338300A11A08 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = ""; }; + 84B67CF02326338300A11A08 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = ""; }; + 84B67CF52326338400A11A08 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 84B67CFA2326338400A11A08 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CGSizeExtension.swift; sourceTree = ""; }; + 84F232D4254C831E0011862E /* model_opt.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = model_opt.tflite; sourceTree = ""; }; + 84FCF5912387BD7900663812 /* tfl_logo.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; name = tfl_logo.png; path = Assets.xcassets/tfl_logo.png; sourceTree = ""; }; + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_Midas.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = 
"Pods-Midas.release.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.release.xcconfig"; sourceTree = ""; }; + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.debug.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.debug.xcconfig"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 84B67CE82326338300A11A08 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 840ECB1E238BAA0D00C7D88A /* Cells */ = { + isa = PBXGroup; + children = ( + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */, + ); + path = Cells; + sourceTree = ""; + }; + 842DDB6C2372A80E00F6BB94 /* Views */ = { + isa = PBXGroup; + children = ( + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */, + ); + path = Views; + sourceTree = ""; + }; + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */ = { + isa = PBXGroup; + children = ( + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */, + ); + path = ModelDataHandler; + sourceTree = ""; + }; + 8474FEC62341D2BE00377D34 /* ViewControllers */ = { + isa = PBXGroup; + children = ( + 84B67CF02326338300A11A08 /* ViewController.swift */, + ); + path = ViewControllers; + sourceTree = ""; + }; + 8474FEC72341D35800377D34 /* Camera Feed */ = { + isa = PBXGroup; + children = ( + 8474FEC82341D36E00377D34 /* PreviewView.swift */, + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */, + ); + path = "Camera Feed"; + sourceTree = ""; + }; + 84884290236FF07F0043FC4C /* RunScripts */ = { + isa = PBXGroup; + children = ( + 84884291236FF0A30043FC4C /* download_models.sh */, + ); + path = RunScripts; + sourceTree = ""; + }; + 848842A22370180C0043FC4C /* Model */ = { + isa = PBXGroup; + children = ( + 84F232D4254C831E0011862E /* model_opt.tflite */, + ); + path = Model; + sourceTree = ""; + }; + 84952CB3236186A20052C104 /* Extensions */ = { + isa = PBXGroup; + children = ( + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */, + 84952CB82361874A0052C104 /* TFLiteExtension.swift */, + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */, + ); + path = Extensions; + sourceTree = ""; + }; + 84B67CE22326338300A11A08 = { + isa = PBXGroup; + children = ( + 8402440023D9834600704ABD /* README.md */, + 84884290236FF07F0043FC4C /* RunScripts */, + 84B67CED2326338300A11A08 /* Midas */, + 84B67CEC2326338300A11A08 /* Products */, + B4DFDCC28443B641BC36251D /* Pods */, + A3DA804B8D3F6891E3A02852 /* Frameworks */, + ); + sourceTree = ""; + }; + 84B67CEC2326338300A11A08 /* Products */ = { + isa = PBXGroup; + children = ( + 84B67CEB2326338300A11A08 /* Midas.app */, + ); + name = Products; + sourceTree = ""; + }; + 84B67CED2326338300A11A08 /* Midas */ = { + isa = PBXGroup; + children = ( + 840ECB1E238BAA0D00C7D88A /* Cells */, + 842DDB6C2372A80E00F6BB94 /* Views */, + 848842A22370180C0043FC4C /* Model */, + 84952CB3236186A20052C104 /* Extensions */, + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */, + 8474FEC72341D35800377D34 /* Camera Feed */, + 8474FEC62341D2BE00377D34 /* ViewControllers */, + 84B67D002326339000A11A08 /* Storyboards */, + 84B67CEE2326338300A11A08 /* AppDelegate.swift */, + 846BAF7523E7FE13006FC136 /* Constants.swift */, + 84B67CF52326338400A11A08 /* Assets.xcassets */, + 
84FCF5912387BD7900663812 /* tfl_logo.png */, + 84B67CFA2326338400A11A08 /* Info.plist */, + ); + path = Midas; + sourceTree = ""; + }; + 84B67D002326339000A11A08 /* Storyboards */ = { + isa = PBXGroup; + children = ( + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */, + 840EDD002341DE380017ED42 /* Main.storyboard */, + ); + path = Storyboards; + sourceTree = ""; + }; + A3DA804B8D3F6891E3A02852 /* Frameworks */ = { + isa = PBXGroup; + children = ( + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; + B4DFDCC28443B641BC36251D /* Pods */ = { + isa = PBXGroup; + children = ( + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */, + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */, + ); + path = Pods; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 84B67CEA2326338300A11A08 /* Midas */ = { + isa = PBXNativeTarget; + buildConfigurationList = 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */; + buildPhases = ( + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */, + 84884298237010B90043FC4C /* Download TensorFlow Lite model */, + 84B67CE72326338300A11A08 /* Sources */, + 84B67CE82326338300A11A08 /* Frameworks */, + 84B67CE92326338300A11A08 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = Midas; + productName = Midas; + productReference = 84B67CEB2326338300A11A08 /* Midas.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 84B67CE32326338300A11A08 /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 1030; + LastUpgradeCheck = 1030; + ORGANIZATIONNAME = tensorflow; + TargetAttributes = { + 84B67CEA2326338300A11A08 = { + CreatedOnToolsVersion = 10.3; + }; + }; + }; + buildConfigurationList = 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 84B67CE22326338300A11A08; + productRefGroup = 84B67CEC2326338300A11A08 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 84B67CEA2326338300A11A08 /* Midas */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 84B67CE92326338300A11A08 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 8402440123D9834600704ABD /* README.md in Resources */, + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */, + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */, + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */, + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */, + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-Midas-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = 
/bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; + 84884298237010B90043FC4C /* Download TensorFlow Lite model */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + ); + name = "Download TensorFlow Lite model"; + outputFileListPaths = ( + ); + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/bash; + shellScript = "\"$SRCROOT/RunScripts/download_models.sh\"\n"; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 84B67CE72326338300A11A08 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */, + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */, + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */, + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */, + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */, + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */, + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */, + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */, + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */, + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */, + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDCFC2341DDD30017ED42 /* Base */, + ); + name = "Launch Screen.storyboard"; + sourceTree = ""; + }; + 840EDD002341DE380017ED42 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDD012341DE380017ED42 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 84B67CFB2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + 
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 84B67CFC2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 84B67CFE2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; 
+ SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 84B67CFF2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFB2326338400A11A08 /* Debug */, + 84B67CFC2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFE2326338400A11A08 /* Debug */, + 84B67CFF2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 84B67CE32326338300A11A08 /* Project object */; +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000000000000000000000000000000000000..919434a6254f0e9651f402737811be6634a03e9c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000000000000000000000000000000000000..18d981003d68d0546c4804ac2ff47dd97c6e7921 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate new file mode 100644 index 0000000000000000000000000000000000000000..1d20756ee57b79e9f9f886453bdb7997ca2ee2d4 Binary files /dev/null and b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git 
a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist new file mode 100644 index 0000000000000000000000000000000000000000..6093f6160eedfdfc20e96396247a7dbc9247cc55 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist @@ -0,0 +1,14 @@ + + + + + SchemeUserState + + PoseNet.xcscheme_^#shared#^_ + + orderHint + 3 + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift new file mode 100644 index 0000000000000000000000000000000000000000..233f0291ab4f379067543bdad3cc198a2dc3ab0f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift @@ -0,0 +1,41 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +@UIApplicationMain +class AppDelegate: UIResponder, UIApplicationDelegate { + + var window: UIWindow? + + func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) 
-> Bool { + return true + } + + func applicationWillResignActive(_ application: UIApplication) { + } + + func applicationDidEnterBackground(_ application: UIApplication) { + } + + func applicationWillEnterForeground(_ application: UIApplication) { + } + + func applicationDidBecomeActive(_ application: UIApplication) { + } + + func applicationWillTerminate(_ application: UIApplication) { + } +} + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..65b74d7ef11fa59fafa829e681ac90906f3ac8b2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1 @@ +{"images":[{"size":"60x60","expected-size":"180","filename":"180.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"40x40","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"60x60","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"57x57","expected-size":"57","filename":"57.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"87","filename":"87.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"57x57","expected-size":"114","filename":"114.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"60","filename":"60.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"1024x1024","filename":"1024.png","expected-size":"1024","idiom":"ios-marketing","folder":"Assets.xcassets/AppIcon.appiconset/","scale":"1x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"72x72","expected-size":"72","filename":"72.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"76x76","expected-size":"152","filename":"152.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"50x50","expected-size":"100","filename":"100.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"76x76","expected-size":"76","filename":"76.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"50x50","expect
ed-size":"50","filename":"50.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"72x72","expected-size":"144","filename":"144.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"40x40","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"83.5x83.5","expected-size":"167","filename":"167.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"20x20","expected-size":"20","filename":"20.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"}]} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift new file mode 100644 index 0000000000000000000000000000000000000000..48d65b88ee220e722fbad2570c8e879a431cd0f5 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift @@ -0,0 +1,316 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import AVFoundation +import UIKit +import os + +// MARK: - CameraFeedManagerDelegate Declaration +@objc protocol CameraFeedManagerDelegate: class { + /// This method delivers the pixel buffer of the current frame seen by the device's camera. + @objc optional func cameraFeedManager( + _ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer + ) + + /// This method initimates that a session runtime error occured. + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) + + /// This method initimates that the session was interrupted. + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) + + /// This method initimates that the session interruption has ended. + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) + + /// This method initimates that there was an error in video configurtion. 
+ func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) + + /// This method initimates that the camera permissions have been denied. + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) +} + +/// This enum holds the state of the camera initialization. +// MARK: - Camera Initialization State Enum +enum CameraConfiguration { + case success + case failed + case permissionDenied +} + +/// This class manages all camera related functionalities. +// MARK: - Camera Related Functionalies Manager +class CameraFeedManager: NSObject { + // MARK: Camera Related Instance Variables + private let session: AVCaptureSession = AVCaptureSession() + + private let previewView: PreviewView + private let sessionQueue = DispatchQueue(label: "sessionQueue") + private var cameraConfiguration: CameraConfiguration = .failed + private lazy var videoDataOutput = AVCaptureVideoDataOutput() + private var isSessionRunning = false + + // MARK: CameraFeedManagerDelegate + weak var delegate: CameraFeedManagerDelegate? + + // MARK: Initializer + init(previewView: PreviewView) { + self.previewView = previewView + super.init() + + // Initializes the session + session.sessionPreset = .high + self.previewView.session = session + self.previewView.previewLayer.connection?.videoOrientation = .portrait + self.previewView.previewLayer.videoGravity = .resizeAspectFill + self.attemptToConfigureSession() + } + + // MARK: Session Start and End methods + + /// This method starts an AVCaptureSession based on whether the camera configuration was successful. + func checkCameraConfigurationAndStartSession() { + sessionQueue.async { + switch self.cameraConfiguration { + case .success: + self.addObservers() + self.startSession() + case .failed: + DispatchQueue.main.async { + self.delegate?.presentVideoConfigurationErrorAlert(self) + } + case .permissionDenied: + DispatchQueue.main.async { + self.delegate?.presentCameraPermissionsDeniedAlert(self) + } + } + } + } + + /// This method stops a running an AVCaptureSession. + func stopSession() { + self.removeObservers() + sessionQueue.async { + if self.session.isRunning { + self.session.stopRunning() + self.isSessionRunning = self.session.isRunning + } + } + + } + + /// This method resumes an interrupted AVCaptureSession. + func resumeInterruptedSession(withCompletion completion: @escaping (Bool) -> Void) { + sessionQueue.async { + self.startSession() + + DispatchQueue.main.async { + completion(self.isSessionRunning) + } + } + } + + /// This method starts the AVCaptureSession + private func startSession() { + self.session.startRunning() + self.isSessionRunning = self.session.isRunning + } + + // MARK: Session Configuration Methods. + /// This method requests for camera permissions and handles the configuration of the session and stores the result of configuration. + private func attemptToConfigureSession() { + switch AVCaptureDevice.authorizationStatus(for: .video) { + case .authorized: + self.cameraConfiguration = .success + case .notDetermined: + self.sessionQueue.suspend() + self.requestCameraAccess(completion: { granted in + self.sessionQueue.resume() + }) + case .denied: + self.cameraConfiguration = .permissionDenied + default: + break + } + + self.sessionQueue.async { + self.configureSession() + } + } + + /// This method requests for camera permissions. 
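+  /// `attemptToConfigureSession()` suspends `sessionQueue` before calling this and resumes it in
+  /// the completion handler, so `configureSession()` only runs after the user has responded.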
+ private func requestCameraAccess(completion: @escaping (Bool) -> Void) { + AVCaptureDevice.requestAccess(for: .video) { (granted) in + if !granted { + self.cameraConfiguration = .permissionDenied + } else { + self.cameraConfiguration = .success + } + completion(granted) + } + } + + /// This method handles all the steps to configure an AVCaptureSession. + private func configureSession() { + guard cameraConfiguration == .success else { + return + } + session.beginConfiguration() + + // Tries to add an AVCaptureDeviceInput. + guard addVideoDeviceInput() == true else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + // Tries to add an AVCaptureVideoDataOutput. + guard addVideoDataOutput() else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + session.commitConfiguration() + self.cameraConfiguration = .success + } + + /// This method tries to an AVCaptureDeviceInput to the current AVCaptureSession. + private func addVideoDeviceInput() -> Bool { + /// Tries to get the default back camera. + guard + let camera = AVCaptureDevice.default(.builtInWideAngleCamera, for: .video, position: .back) + else { + fatalError("Cannot find camera") + } + + do { + let videoDeviceInput = try AVCaptureDeviceInput(device: camera) + if session.canAddInput(videoDeviceInput) { + session.addInput(videoDeviceInput) + return true + } else { + return false + } + } catch { + fatalError("Cannot create video device input") + } + } + + /// This method tries to an AVCaptureVideoDataOutput to the current AVCaptureSession. + private func addVideoDataOutput() -> Bool { + let sampleBufferQueue = DispatchQueue(label: "sampleBufferQueue") + videoDataOutput.setSampleBufferDelegate(self, queue: sampleBufferQueue) + videoDataOutput.alwaysDiscardsLateVideoFrames = true + videoDataOutput.videoSettings = [ + String(kCVPixelBufferPixelFormatTypeKey): kCMPixelFormat_32BGRA + ] + + if session.canAddOutput(videoDataOutput) { + session.addOutput(videoDataOutput) + videoDataOutput.connection(with: .video)?.videoOrientation = .portrait + return true + } + return false + } + + // MARK: Notification Observer Handling + private func addObservers() { + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionRuntimeErrorOccured(notification:)), + name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionWasInterrupted(notification:)), + name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionInterruptionEnded), + name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + private func removeObservers() { + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + // MARK: Notification Observers + @objc func sessionWasInterrupted(notification: Notification) { + if let userInfoValue = notification.userInfo?[AVCaptureSessionInterruptionReasonKey] + as AnyObject?, + let reasonIntegerValue = userInfoValue.integerValue, + let reason = 
AVCaptureSession.InterruptionReason(rawValue: reasonIntegerValue) + { + os_log("Capture session was interrupted with reason: %s", type: .error, reason.rawValue) + + var canResumeManually = false + if reason == .videoDeviceInUseByAnotherClient { + canResumeManually = true + } else if reason == .videoDeviceNotAvailableWithMultipleForegroundApps { + canResumeManually = false + } + + delegate?.cameraFeedManager(self, sessionWasInterrupted: canResumeManually) + + } + } + + @objc func sessionInterruptionEnded(notification: Notification) { + delegate?.cameraFeedManagerDidEndSessionInterruption(self) + } + + @objc func sessionRuntimeErrorOccured(notification: Notification) { + guard let error = notification.userInfo?[AVCaptureSessionErrorKey] as? AVError else { + return + } + + os_log("Capture session runtime error: %s", type: .error, error.localizedDescription) + + if error.code == .mediaServicesWereReset { + sessionQueue.async { + if self.isSessionRunning { + self.startSession() + } else { + DispatchQueue.main.async { + self.delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } + } + } else { + delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } +} + +/// AVCaptureVideoDataOutputSampleBufferDelegate +extension CameraFeedManager: AVCaptureVideoDataOutputSampleBufferDelegate { + /// This method delegates the CVPixelBuffer of the frame seen by the camera currently. + func captureOutput( + _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, + from connection: AVCaptureConnection + ) { + + // Converts the CMSampleBuffer to a CVPixelBuffer. + let pixelBuffer: CVPixelBuffer? = CMSampleBufferGetImageBuffer(sampleBuffer) + + guard let imagePixelBuffer = pixelBuffer else { + return + } + + // Delegates the pixel buffer to the ViewController. + delegate?.cameraFeedManager?(self, didOutput: imagePixelBuffer) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift new file mode 100644 index 0000000000000000000000000000000000000000..308c7ec54308af5c152ff6038670b26501a8e82c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift @@ -0,0 +1,39 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit +import AVFoundation + + /// The camera frame is displayed on this view. +class PreviewView: UIView { + var previewLayer: AVCaptureVideoPreviewLayer { + guard let layer = layer as? AVCaptureVideoPreviewLayer else { + fatalError("Layer expected is of type VideoPreviewLayer") + } + return layer + } + + var session: AVCaptureSession? 
{ + get { + return previewLayer.session + } + set { + previewLayer.session = newValue + } + } + + override class var layerClass: AnyClass { + return AVCaptureVideoPreviewLayer.self + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift new file mode 100644 index 0000000000000000000000000000000000000000..c6be64af5678541ec09fc367b03c80155876f0ba --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift @@ -0,0 +1,21 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// Table cell for inference result in bottom view. +class InfoCell: UITableViewCell { + @IBOutlet weak var fieldNameLabel: UILabel! + @IBOutlet weak var infoLabel: UILabel! +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift new file mode 100644 index 0000000000000000000000000000000000000000..b0789ee58a1ea373d441f05333d8ce8914adadb7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift @@ -0,0 +1,25 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +enum Constants { + // MARK: - Constants related to the image processing + static let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2) + static let rgbPixelChannels = 3 + static let maxRGBValue: Float32 = 255.0 + + // MARK: - Constants related to the model interperter + static let defaultThreadCount = 2 + static let defaultDelegate: Delegates = .CPU +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..031550ea0081963d18b5b83712854babaf7c0a34 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift @@ -0,0 +1,45 @@ +// Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CGSize { + /// Returns `CGAfineTransform` to resize `self` to fit in destination size, keeping aspect ratio + /// of `self`. `self` image is resized to be inscribe to destination size and located in center of + /// destination. + /// + /// - Parameter toFitIn: destination size to be filled. + /// - Returns: `CGAffineTransform` to transform `self` image to `dest` image. + func transformKeepAspect(toFitIn dest: CGSize) -> CGAffineTransform { + let sourceRatio = self.height / self.width + let destRatio = dest.height / dest.width + + // Calculates ratio `self` to `dest`. + var ratio: CGFloat + var x: CGFloat = 0 + var y: CGFloat = 0 + if sourceRatio > destRatio { + // Source size is taller than destination. Resized to fit in destination height, and find + // horizontal starting point to be centered. + ratio = dest.height / self.height + x = (dest.width - self.width * ratio) / 2 + } else { + ratio = dest.width / self.width + y = (dest.height - self.height * ratio) / 2 + } + return CGAffineTransform(a: ratio, b: 0, c: 0, d: ratio, tx: x, ty: y) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..4899c76562a546c513736fbf4556629b08d2c929 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift @@ -0,0 +1,172 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CVPixelBuffer { + var size: CGSize { + return CGSize(width: CVPixelBufferGetWidth(self), height: CVPixelBufferGetHeight(self)) + } + + /// Returns a new `CVPixelBuffer` created by taking the self area and resizing it to the + /// specified target size. Aspect ratios of source image and destination image are expected to be + /// same. + /// + /// - Parameters: + /// - from: Source area of image to be cropped and resized. + /// - to: Size to scale the image to(i.e. 
image size used while training the model). + /// - Returns: The cropped and resized image of itself. + func resize(from source: CGRect, to size: CGSize) -> CVPixelBuffer? { + let rect = CGRect(origin: CGPoint(x: 0, y: 0), size: self.size) + guard rect.contains(source) else { + os_log("Resizing Error: source area is out of index", type: .error) + return nil + } + guard rect.size.width / rect.size.height - source.size.width / source.size.height < 1e-5 + else { + os_log( + "Resizing Error: source image ratio and destination image ratio is different", + type: .error) + return nil + } + + let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self) + let imageChannels = 4 + + CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) + defer { CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) } + + // Finds the address of the upper leftmost pixel of the source area. + guard + let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced( + by: Int(source.minY) * inputImageRowBytes + Int(source.minX) * imageChannels) + else { + return nil + } + + // Crops given area as vImage Buffer. + var croppedImage = vImage_Buffer( + data: inputBaseAddress, height: UInt(source.height), width: UInt(source.width), + rowBytes: inputImageRowBytes) + + let resultRowBytes = Int(size.width) * imageChannels + guard let resultAddress = malloc(Int(size.height) * resultRowBytes) else { + return nil + } + + // Allocates a vacant vImage buffer for resized image. + var resizedImage = vImage_Buffer( + data: resultAddress, + height: UInt(size.height), width: UInt(size.width), + rowBytes: resultRowBytes + ) + + // Performs the scale operation on cropped image and stores it in result image buffer. + guard vImageScale_ARGB8888(&croppedImage, &resizedImage, nil, vImage_Flags(0)) == kvImageNoError + else { + return nil + } + + let releaseCallBack: CVPixelBufferReleaseBytesCallback = { mutablePointer, pointer in + if let pointer = pointer { + free(UnsafeMutableRawPointer(mutating: pointer)) + } + } + + var result: CVPixelBuffer? + + // Converts the thumbnail vImage buffer to CVPixelBuffer + let conversionStatus = CVPixelBufferCreateWithBytes( + nil, + Int(size.width), Int(size.height), + CVPixelBufferGetPixelFormatType(self), + resultAddress, + resultRowBytes, + releaseCallBack, + nil, + nil, + &result + ) + + guard conversionStatus == kCVReturnSuccess else { + free(resultAddress) + return nil + } + + return result + } + + /// Returns the RGB `Data` representation of the given image buffer. + /// + /// - Parameters: + /// - isModelQuantized: Whether the model is quantized (i.e. fixed point values rather than + /// floating point values). + /// - Returns: The RGB data representation of the image buffer or `nil` if the buffer could not be + /// converted. + func rgbData( + isModelQuantized: Bool + ) -> Data? { + CVPixelBufferLockBaseAddress(self, .readOnly) + defer { CVPixelBufferUnlockBaseAddress(self, .readOnly) } + guard let sourceData = CVPixelBufferGetBaseAddress(self) else { + return nil + } + + let width = CVPixelBufferGetWidth(self) + let height = CVPixelBufferGetHeight(self) + let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(self) + let destinationBytesPerRow = Constants.rgbPixelChannels * width + + // Assign input image to `sourceBuffer` to convert it. 
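+    // The source buffer is still 4-channel BGRA/ARGB at this point; the vImageConvert_BGRA8888toRGB888
+    // call below drops the alpha channel, which is why `destinationBytesPerRow` above is only
+    // `Constants.rgbPixelChannels * width`.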
+ var sourceBuffer = vImage_Buffer( + data: sourceData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: sourceBytesPerRow) + + // Make `destinationBuffer` and `destinationData` for its data to be assigned. + guard let destinationData = malloc(height * destinationBytesPerRow) else { + os_log("Error: out of memory", type: .error) + return nil + } + defer { free(destinationData) } + var destinationBuffer = vImage_Buffer( + data: destinationData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: destinationBytesPerRow) + + // Convert image type. + switch CVPixelBufferGetPixelFormatType(self) { + case kCVPixelFormatType_32BGRA: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + case kCVPixelFormatType_32ARGB: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + default: + os_log("The type of this image is not supported.", type: .error) + return nil + } + + // Make `Data` with converted image. + let imageByteData = Data( + bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height) + + if isModelQuantized { return imageByteData } + + let imageBytes = [UInt8](imageByteData) + return Data(copyingBufferOf: imageBytes.map { Float($0) / Constants.maxRGBValue }) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..63f7ced786e2b550391c77af534d1d3c431522c6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift @@ -0,0 +1,75 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite + +// MARK: - Data +extension Data { + /// Creates a new buffer by copying the buffer pointer of the given array. + /// + /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit + /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting + /// data from the resulting buffer has undefined behavior. + /// - Parameter array: An array with elements of type `T`. + init(copyingBufferOf array: [T]) { + self = array.withUnsafeBufferPointer(Data.init) + } + + /// Convert a Data instance to Array representation. + func toArray(type: T.Type) -> [T] where T: AdditiveArithmetic { + var array = [T](repeating: T.zero, count: self.count / MemoryLayout.stride) + _ = array.withUnsafeMutableBytes { self.copyBytes(to: $0) } + return array + } +} + +// MARK: - Wrappers +/// Struct for handling multidimension `Data` in flat `Array`. 
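+/// For example, with `dimensions == [1, 256, 256, 1]` (the MiDaS output shape declared in the
+/// `Model` enum of ModelDataHandler.swift), element `[0, r, c, 0]` maps to flat index `r * 256 + c`.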
+struct FlatArray { + private var array: [Element] + var dimensions: [Int] + + init(tensor: Tensor) { + dimensions = tensor.shape.dimensions + array = tensor.data.toArray(type: Element.self) + } + + private func flatIndex(_ index: [Int]) -> Int { + guard index.count == dimensions.count else { + fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).") + } + + var result = 0 + for i in 0.. index[i] else { + fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])") + } + result = dimensions[i] * result + index[i] + } + return result + } + + subscript(_ index: Int...) -> Element { + get { + return array[flatIndex(index)] + } + set(newValue) { + array[flatIndex(index)] = newValue + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..4330d9b33f31010549802febc6f6f2bc9fd9b950 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist @@ -0,0 +1,42 @@ + + + + + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + NSCameraUsageDescription + This app will use camera to continuously estimate the depth map. + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift new file mode 100644 index 0000000000000000000000000000000000000000..144cfe1fa3a65af5adcb572237f2bf9718e570ae --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift @@ -0,0 +1,464 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite +import UIKit + +/// This class handles all data preprocessing and makes calls to run inference on a given frame +/// by invoking the `Interpreter`. It then formats the inferences obtained. +class ModelDataHandler { + // MARK: - Private Properties + + /// TensorFlow Lite `Interpreter` object for performing inference on a given model. + private var interpreter: Interpreter + + /// TensorFlow lite `Tensor` of model input and output. 
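+  /// MiDaS here uses a single input tensor (1 x 256 x 256 x 3, RGB) and a single output tensor
+  /// (1 x 256 x 256 x 1, the estimated depth map); see the `Model` enum near the end of this file.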
+ private var inputTensor: Tensor + + //private var heatsTensor: Tensor + //private var offsetsTensor: Tensor + private var outputTensor: Tensor + // MARK: - Initialization + + /// A failable initializer for `ModelDataHandler`. A new instance is created if the model is + /// successfully loaded from the app's main bundle. Default `threadCount` is 2. + init( + threadCount: Int = Constants.defaultThreadCount, + delegate: Delegates = Constants.defaultDelegate + ) throws { + // Construct the path to the model file. + guard + let modelPath = Bundle.main.path( + forResource: Model.file.name, + ofType: Model.file.extension + ) + else { + fatalError("Failed to load the model file with name: \(Model.file.name).") + } + + // Specify the options for the `Interpreter`. + var options = Interpreter.Options() + options.threadCount = threadCount + + // Specify the delegates for the `Interpreter`. + var delegates: [Delegate]? + switch delegate { + case .Metal: + delegates = [MetalDelegate()] + case .CoreML: + if let coreMLDelegate = CoreMLDelegate() { + delegates = [coreMLDelegate] + } else { + delegates = nil + } + default: + delegates = nil + } + + // Create the `Interpreter`. + interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates) + + // Initialize input and output `Tensor`s. + // Allocate memory for the model's input `Tensor`s. + try interpreter.allocateTensors() + + // Get allocated input and output `Tensor`s. + inputTensor = try interpreter.input(at: 0) + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + /* + // Check if input and output `Tensor`s are in the expected formats. + guard (inputTensor.dataType == .uInt8) == Model.isQuantized else { + fatalError("Unexpected Model: quantization is \(!Model.isQuantized)") + } + + guard inputTensor.shape.dimensions[0] == Model.input.batchSize, + inputTensor.shape.dimensions[1] == Model.input.height, + inputTensor.shape.dimensions[2] == Model.input.width, + inputTensor.shape.dimensions[3] == Model.input.channelSize + else { + fatalError("Unexpected Model: input shape") + } + + + guard heatsTensor.shape.dimensions[0] == Model.output.batchSize, + heatsTensor.shape.dimensions[1] == Model.output.height, + heatsTensor.shape.dimensions[2] == Model.output.width, + heatsTensor.shape.dimensions[3] == Model.output.keypointSize + else { + fatalError("Unexpected Model: heat tensor") + } + + guard offsetsTensor.shape.dimensions[0] == Model.output.batchSize, + offsetsTensor.shape.dimensions[1] == Model.output.height, + offsetsTensor.shape.dimensions[2] == Model.output.width, + offsetsTensor.shape.dimensions[3] == Model.output.offsetSize + else { + fatalError("Unexpected Model: offset tensor") + } + */ + + } + + /// Runs Midas model with given image with given source area to destination area. + /// + /// - Parameters: + /// - on: Input image to run the model. + /// - from: Range of input image to run the model. + /// - to: Size of view to render the result. + /// - Returns: Result of the inference and the times consumed in every steps. + func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize) + //-> (Result, Times)? + //-> (FlatArray, Times)? + -> ([Float], Int, Int, Times)? + { + // Start times of each process. + let preprocessingStartTime: Date + let inferenceStartTime: Date + let postprocessingStartTime: Date + + // Processing times in miliseconds. 
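+    // Each value is computed below as `Date().timeIntervalSince(start) * 1000`.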
+ let preprocessingTime: TimeInterval + let inferenceTime: TimeInterval + let postprocessingTime: TimeInterval + + preprocessingStartTime = Date() + guard let data = preprocess(of: pixelbuffer, from: source) else { + os_log("Preprocessing failed", type: .error) + return nil + } + preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000 + + inferenceStartTime = Date() + inference(from: data) + inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000 + + postprocessingStartTime = Date() + //guard let result = postprocess(to: dest) else { + // os_log("Postprocessing failed", type: .error) + // return nil + //} + postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000 + + + let results: [Float] + switch outputTensor.dataType { + case .uInt8: + guard let quantization = outputTensor.quantizationParameters else { + print("No results returned because the quantization values for the output tensor are nil.") + return nil + } + let quantizedResults = [UInt8](outputTensor.data) + results = quantizedResults.map { + quantization.scale * Float(Int($0) - quantization.zeroPoint) + } + case .float32: + results = [Float32](unsafeData: outputTensor.data) ?? [] + default: + print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.") + return nil + } + + + let times = Times( + preprocessing: preprocessingTime, + inference: inferenceTime, + postprocessing: postprocessingTime) + + return (results, Model.input.width, Model.input.height, times) + } + + // MARK: - Private functions to run model + /// Preprocesses given rectangle image to be `Data` of disired size by croping and resizing it. + /// + /// - Parameters: + /// - of: Input image to crop and resize. + /// - from: Target area to be cropped and resized. + /// - Returns: The cropped and resized image. `nil` if it can not be processed. + private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? { + let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer) + assert(sourcePixelFormat == kCVPixelFormatType_32BGRA) + + // Resize `targetSquare` of input image to `modelSize`. + let modelSize = CGSize(width: Model.input.width, height: Model.input.height) + guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize) + else { + return nil + } + + // Remove the alpha component from the image buffer to get the initialized `Data`. + let byteCount = + Model.input.batchSize + * Model.input.height * Model.input.width + * Model.input.channelSize + guard + let inputData = thumbnail.rgbData( + isModelQuantized: Model.isQuantized + ) + else { + os_log("Failed to convert the image buffer to RGB data.", type: .error) + return nil + } + + return inputData + } + + + + /* + /// Postprocesses output `Tensor`s to `Result` with size of view to render the result. + /// + /// - Parameters: + /// - to: Size of view to be displaied. + /// - Returns: Postprocessed `Result`. `nil` if it can not be processed. + private func postprocess(to viewSize: CGSize) -> Result? { + // MARK: Formats output tensors + // Convert `Tensor` to `FlatArray`. As Midas is not quantized, convert them to Float type + // `FlatArray`. + let heats = FlatArray(tensor: heatsTensor) + let offsets = FlatArray(tensor: offsetsTensor) + + // MARK: Find position of each key point + // Finds the (row, col) locations of where the keypoints are most likely to be. The highest + // `heats[0, row, col, keypoint]` value, the more likely `keypoint` being located in (`row`, + // `col`). 
+ let keypointPositions = (0.. (Int, Int) in + var maxValue = heats[0, 0, 0, keypoint] + var maxRow = 0 + var maxCol = 0 + for row in 0.. maxValue { + maxValue = heats[0, row, col, keypoint] + maxRow = row + maxCol = col + } + } + } + return (maxRow, maxCol) + } + + // MARK: Calculates total confidence score + // Calculates total confidence score of each key position. + let totalScoreSum = keypointPositions.enumerated().reduce(0.0) { accumulator, elem -> Float32 in + accumulator + sigmoid(heats[0, elem.element.0, elem.element.1, elem.offset]) + } + let totalScore = totalScoreSum / Float32(Model.output.keypointSize) + + // MARK: Calculate key point position on model input + // Calculates `KeyPoint` coordination model input image with `offsets` adjustment. + let coords = keypointPositions.enumerated().map { index, elem -> (y: Float32, x: Float32) in + let (y, x) = elem + let yCoord = + Float32(y) / Float32(Model.output.height - 1) * Float32(Model.input.height) + + offsets[0, y, x, index] + let xCoord = + Float32(x) / Float32(Model.output.width - 1) * Float32(Model.input.width) + + offsets[0, y, x, index + Model.output.keypointSize] + return (y: yCoord, x: xCoord) + } + + // MARK: Transform key point position and make lines + // Make `Result` from `keypointPosition'. Each point is adjusted to `ViewSize` to be drawn. + var result = Result(dots: [], lines: [], score: totalScore) + var bodyPartToDotMap = [BodyPart: CGPoint]() + for (index, part) in BodyPart.allCases.enumerated() { + let position = CGPoint( + x: CGFloat(coords[index].x) * viewSize.width / CGFloat(Model.input.width), + y: CGFloat(coords[index].y) * viewSize.height / CGFloat(Model.input.height) + ) + bodyPartToDotMap[part] = position + result.dots.append(position) + } + + do { + try result.lines = BodyPart.lines.map { map throws -> Line in + guard let from = bodyPartToDotMap[map.from] else { + throw PostprocessError.missingBodyPart(of: map.from) + } + guard let to = bodyPartToDotMap[map.to] else { + throw PostprocessError.missingBodyPart(of: map.to) + } + return Line(from: from, to: to) + } + } catch PostprocessError.missingBodyPart(let missingPart) { + os_log("Postprocessing error: %s is missing.", type: .error, missingPart.rawValue) + return nil + } catch { + os_log("Postprocessing error: %s", type: .error, error.localizedDescription) + return nil + } + + return result + } +*/ + + + + /// Run inference with given `Data` + /// + /// Parameter `from`: `Data` of input image to run model. + private func inference(from data: Data) { + // Copy the initialized `Data` to the input `Tensor`. + do { + try interpreter.copy(data, toInputAt: 0) + + // Run inference by invoking the `Interpreter`. + try interpreter.invoke() + + // Get the output `Tensor` to process the inference results. + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + + } catch let error { + os_log( + "Failed to invoke the interpreter with error: %s", type: .error, + error.localizedDescription) + return + } + } + + /// Returns value within [0,1]. 
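+  /// For example, `sigmoid(0) == 0.5`, and large positive inputs approach 1. It is only referenced
+  /// by the commented-out keypoint postprocessing above, which appears to be left over from the
+  /// PoseNet example this app was adapted from.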
+ private func sigmoid(_ x: Float32) -> Float32 { + return (1.0 / (1.0 + exp(-x))) + } +} + +// MARK: - Data types for inference result +struct KeyPoint { + var bodyPart: BodyPart = BodyPart.NOSE + var position: CGPoint = CGPoint() + var score: Float = 0.0 +} + +struct Line { + let from: CGPoint + let to: CGPoint +} + +struct Times { + var preprocessing: Double + var inference: Double + var postprocessing: Double +} + +struct Result { + var dots: [CGPoint] + var lines: [Line] + var score: Float +} + +enum BodyPart: String, CaseIterable { + case NOSE = "nose" + case LEFT_EYE = "left eye" + case RIGHT_EYE = "right eye" + case LEFT_EAR = "left ear" + case RIGHT_EAR = "right ear" + case LEFT_SHOULDER = "left shoulder" + case RIGHT_SHOULDER = "right shoulder" + case LEFT_ELBOW = "left elbow" + case RIGHT_ELBOW = "right elbow" + case LEFT_WRIST = "left wrist" + case RIGHT_WRIST = "right wrist" + case LEFT_HIP = "left hip" + case RIGHT_HIP = "right hip" + case LEFT_KNEE = "left knee" + case RIGHT_KNEE = "right knee" + case LEFT_ANKLE = "left ankle" + case RIGHT_ANKLE = "right ankle" + + /// List of lines connecting each part. + static let lines = [ + (from: BodyPart.LEFT_WRIST, to: BodyPart.LEFT_ELBOW), + (from: BodyPart.LEFT_ELBOW, to: BodyPart.LEFT_SHOULDER), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.RIGHT_SHOULDER, to: BodyPart.RIGHT_ELBOW), + (from: BodyPart.RIGHT_ELBOW, to: BodyPart.RIGHT_WRIST), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.LEFT_HIP), + (from: BodyPart.LEFT_HIP, to: BodyPart.RIGHT_HIP), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.LEFT_HIP, to: BodyPart.LEFT_KNEE), + (from: BodyPart.LEFT_KNEE, to: BodyPart.LEFT_ANKLE), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_KNEE), + (from: BodyPart.RIGHT_KNEE, to: BodyPart.RIGHT_ANKLE), + ] +} + +// MARK: - Delegates Enum +enum Delegates: Int, CaseIterable { + case CPU + case Metal + case CoreML + + var description: String { + switch self { + case .CPU: + return "CPU" + case .Metal: + return "GPU" + case .CoreML: + return "NPU" + } + } +} + +// MARK: - Custom Errors +enum PostprocessError: Error { + case missingBodyPart(of: BodyPart) +} + +// MARK: - Information about the model file. +typealias FileInfo = (name: String, extension: String) + +enum Model { + static let file: FileInfo = ( + name: "model_opt", extension: "tflite" + ) + + static let input = (batchSize: 1, height: 256, width: 256, channelSize: 3) + static let output = (batchSize: 1, height: 256, width: 256, channelSize: 1) + static let isQuantized = false +} + + +extension Array { + /// Creates a new array from the bytes of the given unsafe data. + /// + /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit + /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in + /// the `unsafeData`'s buffer to a new array returns an unsafe copy. + /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of + /// `MemoryLayout.stride`. + /// - Parameter unsafeData: The data containing the bytes to turn into an array. 
+ init?(unsafeData: Data) { + guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } + #if swift(>=5.0) + self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) } + #else + self = unsafeData.withUnsafeBytes { + .init(UnsafeBufferPointer( + start: $0, + count: unsafeData.count / MemoryLayout.stride + )) + } + #endif // swift(>=5.0) + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..a04c79f554777863bd0dc8287bfd60704ce28bf2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..5f5623794bd35b9bb75efd7b7e249fd7357fdfbd --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard @@ -0,0 +1,236 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift new file mode 100644 index 0000000000000000000000000000000000000000..fbb51b5a303412c0bbd158d76d025cf88fee6f8f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift @@ -0,0 +1,489 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
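+
+// ViewController drives the live demo: it shows the camera preview, hands each frame to
+// `ModelDataHandler.runMidas(on:from:to:)`, and renders the returned depth values as a
+// grayscale image layered over the preview.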
+ +import AVFoundation +import UIKit +import os + + +public struct PixelData { + var a: UInt8 + var r: UInt8 + var g: UInt8 + var b: UInt8 +} + +extension UIImage { + convenience init?(pixels: [PixelData], width: Int, height: Int) { + guard width > 0 && height > 0, pixels.count == width * height else { return nil } + var data = pixels + guard let providerRef = CGDataProvider(data: Data(bytes: &data, count: data.count * MemoryLayout.size) as CFData) + else { return nil } + guard let cgim = CGImage( + width: width, + height: height, + bitsPerComponent: 8, + bitsPerPixel: 32, + bytesPerRow: width * MemoryLayout.size, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue), + provider: providerRef, + decode: nil, + shouldInterpolate: false, + intent: .defaultIntent) + else { return nil } + self.init(cgImage: cgim) + } +} + + +class ViewController: UIViewController { + // MARK: Storyboards Connections + @IBOutlet weak var previewView: PreviewView! + + //@IBOutlet weak var overlayView: OverlayView! + @IBOutlet weak var overlayView: UIImageView! + + private var imageView : UIImageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)) + + private var imageViewInitialized: Bool = false + + @IBOutlet weak var resumeButton: UIButton! + @IBOutlet weak var cameraUnavailableLabel: UILabel! + + @IBOutlet weak var tableView: UITableView! + + @IBOutlet weak var threadCountLabel: UILabel! + @IBOutlet weak var threadCountStepper: UIStepper! + + @IBOutlet weak var delegatesControl: UISegmentedControl! + + // MARK: ModelDataHandler traits + var threadCount: Int = Constants.defaultThreadCount + var delegate: Delegates = Constants.defaultDelegate + + // MARK: Result Variables + // Inferenced data to render. + private var inferencedData: InferencedData? + + // Minimum score to render the result. + private let minimumScore: Float = 0.5 + + private var avg_latency: Double = 0.0 + + // Relative location of `overlayView` to `previewView`. + private var overlayViewFrame: CGRect? + + private var previewViewFrame: CGRect? + + // MARK: Controllers that manage functionality + // Handles all the camera related functionality + private lazy var cameraCapture = CameraFeedManager(previewView: previewView) + + // Handles all data preprocessing and makes calls to run inference. + private var modelDataHandler: ModelDataHandler? + + // MARK: View Handling Methods + override func viewDidLoad() { + super.viewDidLoad() + + do { + modelDataHandler = try ModelDataHandler() + } catch let error { + fatalError(error.localizedDescription) + } + + cameraCapture.delegate = self + tableView.delegate = self + tableView.dataSource = self + + // MARK: UI Initialization + // Setup thread count stepper with white color. + // https://forums.developer.apple.com/thread/121495 + threadCountStepper.setDecrementImage( + threadCountStepper.decrementImage(for: .normal), for: .normal) + threadCountStepper.setIncrementImage( + threadCountStepper.incrementImage(for: .normal), for: .normal) + // Setup initial stepper value and its label. + threadCountStepper.value = Double(Constants.defaultThreadCount) + threadCountLabel.text = Constants.defaultThreadCount.description + + // Setup segmented controller's color. 
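+    // Segment titles come from the `Delegates` enum (CPU / GPU / NPU); selecting a segment
+    // rebuilds `ModelDataHandler` with the chosen delegate in `didChangeDelegate(_:)`.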
+ delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.lightGray], + for: .normal) + delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.black], + for: .selected) + // Remove existing segments to initialize it with `Delegates` entries. + delegatesControl.removeAllSegments() + Delegates.allCases.forEach { delegate in + delegatesControl.insertSegment( + withTitle: delegate.description, + at: delegate.rawValue, + animated: false) + } + delegatesControl.selectedSegmentIndex = 0 + } + + override func viewWillAppear(_ animated: Bool) { + super.viewWillAppear(animated) + + cameraCapture.checkCameraConfigurationAndStartSession() + } + + override func viewWillDisappear(_ animated: Bool) { + cameraCapture.stopSession() + } + + override func viewDidLayoutSubviews() { + overlayViewFrame = overlayView.frame + previewViewFrame = previewView.frame + } + + // MARK: Button Actions + @IBAction func didChangeThreadCount(_ sender: UIStepper) { + let changedCount = Int(sender.value) + if threadCountLabel.text == changedCount.description { + return + } + + do { + modelDataHandler = try ModelDataHandler(threadCount: changedCount, delegate: delegate) + } catch let error { + fatalError(error.localizedDescription) + } + threadCount = changedCount + threadCountLabel.text = changedCount.description + os_log("Thread count is changed to: %d", threadCount) + } + + @IBAction func didChangeDelegate(_ sender: UISegmentedControl) { + guard let changedDelegate = Delegates(rawValue: delegatesControl.selectedSegmentIndex) else { + fatalError("Unexpected value from delegates segemented controller.") + } + do { + modelDataHandler = try ModelDataHandler(threadCount: threadCount, delegate: changedDelegate) + } catch let error { + fatalError(error.localizedDescription) + } + delegate = changedDelegate + os_log("Delegate is changed to: %s", delegate.description) + } + + @IBAction func didTapResumeButton(_ sender: Any) { + cameraCapture.resumeInterruptedSession { complete in + + if complete { + self.resumeButton.isHidden = true + self.cameraUnavailableLabel.isHidden = true + } else { + self.presentUnableToResumeSessionAlert() + } + } + } + + func presentUnableToResumeSessionAlert() { + let alert = UIAlertController( + title: "Unable to Resume Session", + message: "There was an error while attempting to resume session.", + preferredStyle: .alert + ) + alert.addAction(UIAlertAction(title: "OK", style: .default, handler: nil)) + + self.present(alert, animated: true) + } +} + +// MARK: - CameraFeedManagerDelegate Methods +extension ViewController: CameraFeedManagerDelegate { + func cameraFeedManager(_ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer) { + runModel(on: pixelBuffer) + } + + // MARK: Session Handling Alerts + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) { + // Handles session run time error by updating the UI and providing a button if session can be + // manually resumed. + self.resumeButton.isHidden = false + } + + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) { + // Updates the UI when session is interupted. + if canResumeManually { + self.resumeButton.isHidden = false + } else { + self.cameraUnavailableLabel.isHidden = false + } + } + + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) { + // Updates UI once session interruption has ended. 
+ self.cameraUnavailableLabel.isHidden = true + self.resumeButton.isHidden = true + } + + func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Confirguration Failed", message: "Configuration of camera has failed.", + preferredStyle: .alert) + let okAction = UIAlertAction(title: "OK", style: .cancel, handler: nil) + alertController.addAction(okAction) + + present(alertController, animated: true, completion: nil) + } + + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Camera Permissions Denied", + message: + "Camera permissions have been denied for this app. You can change this by going to Settings", + preferredStyle: .alert) + + let cancelAction = UIAlertAction(title: "Cancel", style: .cancel, handler: nil) + let settingsAction = UIAlertAction(title: "Settings", style: .default) { action in + if let url = URL.init(string: UIApplication.openSettingsURLString) { + UIApplication.shared.open(url, options: [:], completionHandler: nil) + } + } + + alertController.addAction(cancelAction) + alertController.addAction(settingsAction) + + present(alertController, animated: true, completion: nil) + } + + @objc func runModel(on pixelBuffer: CVPixelBuffer) { + guard let overlayViewFrame = overlayViewFrame, let previewViewFrame = previewViewFrame + else { + return + } + // To put `overlayView` area as model input, transform `overlayViewFrame` following transform + // from `previewView` to `pixelBuffer`. `previewView` area is transformed to fit in + // `pixelBuffer`, because `pixelBuffer` as a camera output is resized to fill `previewView`. + // https://developer.apple.com/documentation/avfoundation/avlayervideogravity/1385607-resizeaspectfill + let modelInputRange = overlayViewFrame.applying( + previewViewFrame.size.transformKeepAspect(toFitIn: pixelBuffer.size)) + + // Run Midas model. + guard + let (result, width, height, times) = self.modelDataHandler?.runMidas( + on: pixelBuffer, + from: modelInputRange, + to: overlayViewFrame.size) + else { + os_log("Cannot get inference result.", type: .error) + return + } + + if avg_latency == 0 { + avg_latency = times.inference + } else { + avg_latency = times.inference*0.1 + avg_latency*0.9 + } + + // Udpate `inferencedData` to render data in `tableView`. + inferencedData = InferencedData(score: Float(avg_latency), times: times) + + //let height = 256 + //let width = 256 + + let outputs = result + let outputs_size = width * height; + + var multiplier : Float = 1.0; + + let max_val : Float = outputs.max() ?? 0 + let min_val : Float = outputs.min() ?? 0 + + if((max_val - min_val) > 0) { + multiplier = 255 / (max_val - min_val); + } + + // Draw result. 
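+    // Each output value is min-max normalized into [0, 255] and written to the R, G and B
+    // channels of a pixel buffer, so the depth map is displayed as a grayscale image.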
+ DispatchQueue.main.async { + self.tableView.reloadData() + + var pixels: [PixelData] = .init(repeating: .init(a: 255, r: 0, g: 0, b: 0), count: width * height) + + for i in pixels.indices { + //if(i < 1000) + //{ + let val = UInt8((outputs[i] - min_val) * multiplier) + + pixels[i].r = val + pixels[i].g = val + pixels[i].b = val + //} + } + + + /* + pixels[i].a = 255 + pixels[i].r = .random(in: 0...255) + pixels[i].g = .random(in: 0...255) + pixels[i].b = .random(in: 0...255) + } + */ + + DispatchQueue.main.async { + let image = UIImage(pixels: pixels, width: width, height: height) + + self.imageView.image = image + + if (self.imageViewInitialized == false) { + self.imageViewInitialized = true + self.overlayView.addSubview(self.imageView) + self.overlayView.setNeedsDisplay() + } + } + + /* + let image = UIImage(pixels: pixels, width: width, height: height) + + var imageView : UIImageView + imageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)); + imageView.image = image + self.overlayView.addSubview(imageView) + self.overlayView.setNeedsDisplay() + */ + } + } +/* + func drawResult(of result: Result) { + self.overlayView.dots = result.dots + self.overlayView.lines = result.lines + self.overlayView.setNeedsDisplay() + } + + func clearResult() { + self.overlayView.clear() + self.overlayView.setNeedsDisplay() + } + */ + +} + + +// MARK: - TableViewDelegate, TableViewDataSource Methods +extension ViewController: UITableViewDelegate, UITableViewDataSource { + func numberOfSections(in tableView: UITableView) -> Int { + return InferenceSections.allCases.count + } + + func tableView(_ tableView: UITableView, numberOfRowsInSection section: Int) -> Int { + guard let section = InferenceSections(rawValue: section) else { + return 0 + } + + return section.subcaseCount + } + + func tableView(_ tableView: UITableView, cellForRowAt indexPath: IndexPath) -> UITableViewCell { + let cell = tableView.dequeueReusableCell(withIdentifier: "InfoCell") as! 
InfoCell + guard let section = InferenceSections(rawValue: indexPath.section) else { + return cell + } + guard let data = inferencedData else { return cell } + + var fieldName: String + var info: String + + switch section { + case .Score: + fieldName = section.description + info = String(format: "%.3f", data.score) + case .Time: + guard let row = ProcessingTimes(rawValue: indexPath.row) else { + return cell + } + var time: Double + switch row { + case .InferenceTime: + time = data.times.inference + } + fieldName = row.description + info = String(format: "%.2fms", time) + } + + cell.fieldNameLabel.text = fieldName + cell.infoLabel.text = info + + return cell + } + + func tableView(_ tableView: UITableView, heightForRowAt indexPath: IndexPath) -> CGFloat { + guard let section = InferenceSections(rawValue: indexPath.section) else { + return 0 + } + + var height = Traits.normalCellHeight + if indexPath.row == section.subcaseCount - 1 { + height = Traits.separatorCellHeight + Traits.bottomSpacing + } + return height + } + +} + +// MARK: - Private enums +/// UI coinstraint values +fileprivate enum Traits { + static let normalCellHeight: CGFloat = 35.0 + static let separatorCellHeight: CGFloat = 25.0 + static let bottomSpacing: CGFloat = 30.0 +} + +fileprivate struct InferencedData { + var score: Float + var times: Times +} + +/// Type of sections in Info Cell +fileprivate enum InferenceSections: Int, CaseIterable { + case Score + case Time + + var description: String { + switch self { + case .Score: + return "Average" + case .Time: + return "Processing Time" + } + } + + var subcaseCount: Int { + switch self { + case .Score: + return 1 + case .Time: + return ProcessingTimes.allCases.count + } + } +} + +/// Type of processing times in Time section in Info Cell +fileprivate enum ProcessingTimes: Int, CaseIterable { + case InferenceTime + + var description: String { + switch self { + case .InferenceTime: + return "Inference Time" + } + } +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift new file mode 100644 index 0000000000000000000000000000000000000000..3b53910b57563b6a195fd53321fa2a24ebaf3d3f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift @@ -0,0 +1,63 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// UIView for rendering inference output. 
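+/// Draws the keypoint dots and skeleton lines produced by the pose post-processing,
+/// which is currently commented out in `ViewController`; the depth demo renders its
+/// output into a plain `UIImageView` instead.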
+class OverlayView: UIView { + + var dots = [CGPoint]() + var lines = [Line]() + + override func draw(_ rect: CGRect) { + for dot in dots { + drawDot(of: dot) + } + for line in lines { + drawLine(of: line) + } + } + + func drawDot(of dot: CGPoint) { + let dotRect = CGRect( + x: dot.x - Traits.dot.radius / 2, y: dot.y - Traits.dot.radius / 2, + width: Traits.dot.radius, height: Traits.dot.radius) + let dotPath = UIBezierPath(ovalIn: dotRect) + + Traits.dot.color.setFill() + dotPath.fill() + } + + func drawLine(of line: Line) { + let linePath = UIBezierPath() + linePath.move(to: CGPoint(x: line.from.x, y: line.from.y)) + linePath.addLine(to: CGPoint(x: line.to.x, y: line.to.y)) + linePath.close() + + linePath.lineWidth = Traits.line.width + Traits.line.color.setStroke() + + linePath.stroke() + } + + func clear() { + self.dots = [] + self.lines = [] + } +} + +private enum Traits { + static let dot = (radius: CGFloat(5), color: UIColor.orange) + static let line = (width: CGFloat(1.0), color: UIColor.orange) +} diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..5e9461fc96dbbe3c22ca6bbf2bfd7df3981b9462 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile @@ -0,0 +1,12 @@ +# Uncomment the next line to define a global platform for your project + platform :ios, '12.0' + +target 'Midas' do + # Comment the next line if you're not using Swift and don't want to use dynamic frameworks + use_frameworks! + + # Pods for Midas + pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/CoreML', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/Metal', '~> 0.0.1-nightly' +end diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7b8eb29feaa21e67814b035dbd5c5fb2c62a4151 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md @@ -0,0 +1,105 @@ +# Tensorflow Lite MiDaS iOS Example + +### Requirements + +- XCode 11.0 or above +- iOS 12.0 or above, [iOS 14 breaks the NPU Delegate](https://github.com/tensorflow/tensorflow/issues/43339) +- TensorFlow 2.4.0, TensorFlowLiteSwift -> 0.0.1-nightly + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. It uses TensorFlowLiteSwift / C++ libraries on iOS. The code is written in Swift. 
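+
+The app feeds 256x256 RGB frames to the network and reads back a 256x256 depth map (see the `Model` constants in `ModelDataHandler.swift`). Before bundling a different `.tflite` file, it can be worth confirming its tensor shapes; a minimal sketch in Python, assuming TensorFlow is installed and the model file is the `model_opt.tflite` fetched by `RunScripts/download_models.sh`:
+
+```python
+import tensorflow as tf
+
+# Load the TFLite model and inspect its input/output tensors.
+interpreter = tf.lite.Interpreter(model_path="Midas/Model/model_opt.tflite")
+interpreter.allocate_tensors()
+print(interpreter.get_input_details()[0]["shape"])   # the app assumes 1x256x256x3
+print(interpreter.get_output_details()[0]["shape"])  # the app assumes a 256x256 single-channel map
+```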
+
+Paper: https://arxiv.org/abs/1907.01341
+
+> Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
+> René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
+
+### Install TensorFlow
+
+Set the default Python version to Python 3:
+
+```
+echo 'export PATH=/usr/local/opt/python/libexec/bin:$PATH' >> ~/.zshenv
+echo 'alias python=python3' >> ~/.zshenv
+echo 'alias pip=pip3' >> ~/.zshenv
+```
+
+Install TensorFlow:
+
+```shell
+pip install tensorflow
+```
+
+### Install TensorFlowLiteSwift via Cocoapods
+
+Set the required TensorFlowLiteSwift version in the Podfile (`0.0.1-nightly` is recommended): https://github.com/isl-org/MiDaS/blob/master/mobile/ios/Podfile#L9
+
+Install brew, ruby and cocoapods:
+
+```
+ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+brew install mc rbenv ruby-build
+sudo gem install cocoapods
+```
+
+The TensorFlowLiteSwift library is available on [Cocoapods](https://cocoapods.org/). To integrate it into the project, run the following in the project's root directory:
+
+```shell
+pod install
+```
+
+Now open the `Midas.xcworkspace` file in Xcode, select your iPhone device (Xcode -> Product -> Destination -> iPhone) and launch the app (Cmd + R). If everything works, you should see a real-time depth map from your camera.
+
+### Model
+
+The TensorFlow Lite (TFLite) model `midas.tflite` is in the folder `/Midas/Model`.
+
+To use another model, convert it from a TensorFlow saved-model to a TFLite model so that it can be deployed:
+
+```python
+import tensorflow as tf
+
+saved_model_export_dir = "./saved_model"
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_export_dir)
+tflite_model = converter.convert()
+open("model.tflite", "wb").write(tflite_model)
+```
+
+### Setup Xcode
+
+* Open the `.xcworkspace` file in Xcode
+
+* Click your project name (top-left corner) -> change the Bundle Identifier to something unique such as `com.midas.tflite-npu`
+
+* Select your Developer Team (you should be signed in with your Apple ID)
+
+* Connect your iPhone (if you want to run on a real device instead of the simulator) and select it (Xcode -> Product -> Destination -> iPhone)
+
+* In Xcode, click Product -> Run
+
+* On your iPhone, go to Settings -> General -> Device Management (or Profiles) -> Apple Development -> Trust Apple Development
+
+----
+
+Original repository: https://github.com/isl-org/MiDaS
+
+
+### Examples:
+
+| ![photo_2020-09-27_17-43-20](https://user-images.githubusercontent.com/4096485/94367804-9610de80-00e9-11eb-8a23-8b32a6f52d41.jpg) | ![photo_2020-09-27_17-49-22](https://user-images.githubusercontent.com/4096485/94367974-7201cd00-00ea-11eb-8e0a-68eb9ea10f63.jpg) | ![photo_2020-09-27_17-52-30](https://user-images.githubusercontent.com/4096485/94367976-729a6380-00ea-11eb-8ce0-39d3e26dd550.jpg) | ![photo_2020-09-27_17-43-21](https://user-images.githubusercontent.com/4096485/94367807-97420b80-00e9-11eb-9dcd-848ad9e89e03.jpg) |
+|---|---|---|---|
+
+## LICENSE
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..d737b39d966278f5c6bc29802526ab86f8473de4 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Download TF Lite model from the internet if it does not exist. + +TFLITE_MODEL="model_opt.tflite" +TFLITE_FILE="Midas/Model/${TFLITE_MODEL}" +MODEL_SRC="https://github.com/isl-org/MiDaS/releases/download/v2/${TFLITE_MODEL}" + +if test -f "${TFLITE_FILE}"; then + echo "INFO: TF Lite model already exists. Skip downloading and use the local model." +else + curl --create-dirs -o "${TFLITE_FILE}" -LJO "${MODEL_SRC}" + echo "INFO: Downloaded TensorFlow Lite model to ${TFLITE_FILE}." +fi + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1d43c2606767798ee46b34292e0483197424ec23 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md @@ -0,0 +1,131 @@ +# MiDaS for ROS1 by using LibTorch in C++ + +### Requirements + +- Ubuntu 17.10 / 18.04 / 20.04, Debian Stretch +- ROS Melodic for Ubuntu (17.10 / 18.04) / Debian Stretch, ROS Noetic for Ubuntu 20.04 +- C++11 +- LibTorch >= 1.6 + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. + +* input from `image_topic`: `sensor_msgs/Image` - `RGB8` image with any shape +* output to `midas_topic`: `sensor_msgs/Image` - `TYPE_32FC1` inverse relative depth maps in range [0 - 255] with original size and channels=1 + +### Install Dependecies + +* install ROS Melodic for Ubuntu 17.10 / 18.04: +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_melodic_ubuntu_17_18.sh +./install_ros_melodic_ubuntu_17_18.sh +``` + +or Noetic for Ubuntu 20.04: + +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_noetic_ubuntu_20.sh +./install_ros_noetic_ubuntu_20.sh +``` + + +* install LibTorch 1.7 with CUDA 11.0: + +On **Jetson (ARM)**: +```bash +wget https://nvidia.box.com/shared/static/wa34qwrwtk9njtyarwt5nvo6imenfy26.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl +sudo apt-get install python3-pip libopenblas-base libopenmpi-dev +pip3 install Cython +pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl +``` +Or compile LibTorch from source: https://github.com/pytorch/pytorch#from-source + +On **Linux (x86_64)**: +```bash +cd ~/ +wget https://download.pytorch.org/libtorch/cu110/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu110.zip +unzip libtorch-cxx11-abi-shared-with-deps-1.7.0+cu110.zip +``` + +* create symlink for OpenCV: + +```bash +sudo ln -s /usr/include/opencv4 /usr/include/opencv +``` + +* download and install MiDaS: + +```bash +source ~/.bashrc +cd ~/ +mkdir catkin_ws +cd catkin_ws +git clone https://github.com/isl-org/MiDaS +mkdir src +cp -r MiDaS/ros/* src + +chmod +x src/additions/*.sh +chmod +x src/*.sh +chmod +x src/midas_cpp/scripts/*.py +cp src/additions/do_catkin_make.sh ./do_catkin_make.sh +./do_catkin_make.sh +./src/additions/downloads.sh +``` + +### Usage + +* run only `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + +#### Test + +* Test - capture video and show result in the window: + * place any `test.mp4` video file to the directory `~/catkin_ws/src/` + * run `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + * run test nodes in another terminal: `cd ~/catkin_ws/src && ./run_talker_listener_test.sh` and wait 30 seconds + + (to use Python 2, run command `sed -i 's/python3/python2/' ~/catkin_ws/src/midas_cpp/scripts/*.py` ) + +## Mobile version of MiDaS - Monocular Depth Estimation + +### Accuracy + +* MiDaS v2 small - ResNet50 default-decoder 384x384 +* MiDaS v2.1 small - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| MiDaS v2 small 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| MiDaS v2.1 small 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | 
**14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on nVidia GPU + +Inference speed excluding pre and post processing, batch=1, **Frames Per Second** (the higher - the better): + +| Model | Jetson Nano, FPS | RTX 2080Ti, FPS | +|---|---|---| +| MiDaS v2 small 384x384 | 1.6 | 117 | +| MiDaS v2.1 small 256x256 | 8.1 | 232 | +| SpeedUp, X times | **5x** | **2x** | + +### Citation + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d416fc00282aab146326bbba12a9274e1ba29b8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh @@ -0,0 +1,5 @@ +mkdir src +catkin_make +source devel/setup.bash +echo $ROS_PACKAGE_PATH +chmod +x ./devel/setup.bash diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh new file mode 100644 index 0000000000000000000000000000000000000000..9c967d4e2dc7997da26399a063b5a54ecc314eb1 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh @@ -0,0 +1,5 @@ +mkdir ~/.ros +wget https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small-traced.pt +cp ./model-small-traced.pt ~/.ros/model-small-traced.pt + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh new file mode 100644 index 0000000000000000000000000000000000000000..b868112631e9d9bc7bccb601407dfc857b8a99d5 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh @@ -0,0 +1,34 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 +sudo apt-key adv --keyserver 'hkp://ha.pool.sks-keyservers.net:80' --recv-key 421C365BD9FF1F717815A3895523BAEEB01FA116 + +curl -sSL 
'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-melodic-desktop-full + +printf "\nsource /opt/ros/melodic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python-rosinstall +sudo apt-get install python-catkin-tools +sudo apt-get install python-rospy +sudo apt-get install python-rosdep +sudo apt-get install python-roscd +sudo apt-get install python-pip \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh new file mode 100644 index 0000000000000000000000000000000000000000..d73ea1a3d92359819167d735a92d2a650b9bc245 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh @@ -0,0 +1,33 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 + +curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-noetic-desktop-full + +printf "\nsource /opt/ros/noetic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python3-rosinstall +sudo apt-get install python3-catkin-tools +sudo apt-get install python3-rospy +sudo apt-get install python3-rosdep +sudo apt-get install python3-roscd +sudo apt-get install python3-pip \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0ef6073a9c9ce40744e1c81d557c1c68255b95e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh @@ -0,0 +1,16 @@ +cd ~/catkin_ws/src +catkin_create_pkg midas_cpp std_msgs roscpp cv_bridge sensor_msgs image_transport +cd ~/catkin_ws +catkin_make + +chmod +x ~/catkin_ws/devel/setup.bash +printf "\nsource ~/catkin_ws/devel/setup.bash" >> ~/.bashrc +source ~/catkin_ws/devel/setup.bash + + +sudo rosdep init +rosdep update +#rospack depends1 midas_cpp +roscd midas_cpp +#cat package.xml +#rospack depends midas_cpp \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a0d1583fffdc49216c625dfd07af2ae3b01a7a0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh @@ -0,0 +1,2 @@ +source ~/catkin_ws/devel/setup.bash 
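+# Launch the midas_cpp node. The traced model is loaded from the node's working directory
+# (additions/downloads.sh copies model-small-traced.pt into ~/.ros), and the topic names
+# below must match the talker/listener test scripts.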
+roslaunch midas_cpp midas_cpp.launch model_name:="model-small-traced.pt" input_topic:="image_topic" output_topic:="midas_topic" out_orig_size:="true" \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..885341691d217f9c4c8fcb1e4ff568d87788c7b8 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt @@ -0,0 +1,189 @@ +cmake_minimum_required(VERSION 3.0.2) +project(midas_cpp) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED COMPONENTS + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs +) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + +list(APPEND CMAKE_PREFIX_PATH "~/libtorch") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python3.6/dist-packages/torch/lib") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python2.7/dist-packages/torch/lib") + +if(NOT EXISTS "~/libtorch") + if (EXISTS "/usr/local/lib/python3.6/dist-packages/torch") + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python3.6/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python3.6/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python3.6/dist-packages/torch) + + elseif (EXISTS "/usr/local/lib/python2.7/dist-packages/torch") + + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python2.7/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python2.7/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python2.7/dist-packages/torch) + endif() +endif() + + + +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +include_directories( ${OpenCV_INCLUDE_DIRS} ) + +add_executable(midas_cpp src/main.cpp) +target_link_libraries(midas_cpp "${TORCH_LIBRARIES}" "${OpenCV_LIBS} ${catkin_LIBRARIES}") +set_property(TARGET midas_cpp PROPERTY CXX_STANDARD 14) + + + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES midas_cpp +# CATKIN_DEPENDS cv_bridge image_transport roscpp sensor_msgs std_msgs +# DEPENDS system_lib +) + 
+########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include + ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/midas_cpp.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/midas_cpp_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +# catkin_install_python(PROGRAMS +# scripts/my_python_script +# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) 
+# install(FILES +# # myfile1 +# # myfile2 +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_midas_cpp.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +add_custom_command( + TARGET midas_cpp POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/midas_cpp + ${CMAKE_SOURCE_DIR}/midas_cpp +) \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch new file mode 100644 index 0000000000000000000000000000000000000000..88e86f42f668e76ad4976ec6794a8cb0f20cac65 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch new file mode 100644 index 0000000000000000000000000000000000000000..8817a4f4933c56986fe0edc0886b2fded3d3406d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml new file mode 100644 index 0000000000000000000000000000000000000000..9cac90eba75409bd170f73531c54c83c52ff047a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml @@ -0,0 +1,77 @@ + + + midas_cpp + 0.1.0 + The midas_cpp package + + Alexey Bochkovskiy + MIT + https://github.com/isl-org/MiDaS/tree/master/ros + + + + + + + TODO + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + + + + + + + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py new file mode 100644 index 0000000000000000000000000000000000000000..6927ea7a83ac9309e5f883ee974a5dcfa8a2aa3b --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import 
CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("midas_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py new file mode 100644 index 0000000000000000000000000000000000000000..20e235f6958d644b89383752ab18e9e2275f55e5 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener original - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("image_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener_original: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show_orig", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener_original', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git 
a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py new file mode 100644 index 0000000000000000000000000000000000000000..8219cc8632484a2efd02984347c615efad6b78b2 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + + +def talker(): + rospy.init_node('talker', anonymous=True) + + use_camera = rospy.get_param('~use_camera', False) + input_video_file = rospy.get_param('~input_video_file','test.mp4') + # rospy.loginfo(f"Talker - params: use_camera={use_camera}, input_video_file={input_video_file}") + + # rospy.loginfo("Talker: Trying to open a video stream") + if use_camera == True: + cap = cv2.VideoCapture(0) + else: + cap = cv2.VideoCapture(input_video_file) + + pub = rospy.Publisher('image_topic', Image, queue_size=1) + rate = rospy.Rate(30) # 30hz + bridge = CvBridge() + + while not rospy.is_shutdown(): + ret, cv_image = cap.read() + if ret==False: + print("Talker: Video is over") + rospy.loginfo("Video is over") + return + + try: + image = bridge.cv2_to_imgmsg(cv_image, "bgr8") + except CvBridgeError as e: + rospy.logerr("Talker: cv2image conversion failed: ", e) + print(e) + continue + + rospy.loginfo("Talker: Publishing frame") + pub.publish(image) + rate.sleep() + +if __name__ == '__main__': + try: + talker() + except rospy.ROSInterruptException: + pass diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4fc72c6955f66af71c9cb1fc7a7b1f643129685 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include + +#include + +#include // One-stop header. + +#include +#include +#include +#include + +#include +#include + +// includes for OpenCV >= 3.x +#ifndef CV_VERSION_EPOCH +#include +#include +#include +#endif + +// OpenCV includes for OpenCV 2.x +#ifdef CV_VERSION_EPOCH +#include +#include +#include +#include +#endif + +static const std::string OPENCV_WINDOW = "Image window"; + +class Midas +{ + ros::NodeHandle nh_; + image_transport::ImageTransport it_; + image_transport::Subscriber image_sub_; + image_transport::Publisher image_pub_; + + torch::jit::script::Module module; + torch::Device device; + + auto ToTensor(cv::Mat img, bool show_output = false, bool unsqueeze = false, int unsqueeze_dim = 0) + { + //std::cout << "image shape: " << img.size() << std::endl; + at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte); + + if (unsqueeze) + { + tensor_image.unsqueeze_(unsqueeze_dim); + //std::cout << "tensors new shape: " << tensor_image.sizes() << std::endl; + } + + if (show_output) + { + std::cout << tensor_image.slice(2, 0, 1) << std::endl; + } + //std::cout << "tenor shape: " << tensor_image.sizes() << std::endl; + return tensor_image; + } + + auto ToInput(at::Tensor tensor_image) + { + // Create a vector of inputs. 
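+        // torch::jit::script::Module::forward() takes a std::vector of torch::jit::IValue;
+        // the single image tensor is wrapped here and converted implicitly.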
+ return std::vector{tensor_image}; + } + + auto ToCvImage(at::Tensor tensor, int cv_type = CV_8UC3) + { + int width = tensor.sizes()[0]; + int height = tensor.sizes()[1]; + try + { + cv::Mat output_mat; + if (cv_type == CV_8UC4 || cv_type == CV_8UC3 || cv_type == CV_8UC2 || cv_type == CV_8UC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_32FC4 || cv_type == CV_32FC3 || cv_type == CV_32FC2 || cv_type == CV_32FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_64FC4 || cv_type == CV_64FC3 || cv_type == CV_64FC2 || cv_type == CV_64FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + + //show_image(output_mat, "converted image from tensor"); + return output_mat.clone(); + } + catch (const c10::Error& e) + { + std::cout << "an error has occured : " << e.msg() << std::endl; + } + return cv::Mat(height, width, CV_8UC3); + } + + std::string input_topic, output_topic, model_name; + bool out_orig_size; + int net_width, net_height; + torch::NoGradGuard guard; + at::Tensor mean, std; + at::Tensor output, tensor; + +public: + Midas() + : nh_(), it_(nh_), device(torch::Device(torch::kCPU)) + { + ros::param::param("~input_topic", input_topic, "image_topic"); + ros::param::param("~output_topic", output_topic, "midas_topic"); + ros::param::param("~model_name", model_name, "model-small-traced.pt"); + ros::param::param("~out_orig_size", out_orig_size, true); + ros::param::param("~net_width", net_width, 256); + ros::param::param("~net_height", net_height, 256); + + std::cout << ", input_topic = " << input_topic << + ", output_topic = " << output_topic << + ", model_name = " << model_name << + ", out_orig_size = " << out_orig_size << + ", net_width = " << net_width << + ", net_height = " << net_height << + std::endl; + + // Subscrive to input video feed and publish output video feed + image_sub_ = it_.subscribe(input_topic, 1, &Midas::imageCb, this); + image_pub_ = it_.advertise(output_topic, 1); + + std::cout << "Try to load torchscript model \n"; + + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
+ module = torch::jit::load(model_name); + } + catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + exit(0); + } + + std::cout << "ok\n"; + + try { + module.eval(); + torch::jit::getProfilingMode() = false; + torch::jit::setGraphExecutorOptimize(true); + + mean = torch::tensor({ 0.485, 0.456, 0.406 }); + std = torch::tensor({ 0.229, 0.224, 0.225 }); + + if (torch::hasCUDA()) { + std::cout << "cuda is available" << std::endl; + at::globalContext().setBenchmarkCuDNN(true); + device = torch::Device(torch::kCUDA); + module.to(device); + mean = mean.to(device); + std = std.to(device); + } + } + catch (const c10::Error& e) + { + std::cerr << " module initialization: " << e.msg() << std::endl; + } + } + + ~Midas() + { + } + + void imageCb(const sensor_msgs::ImageConstPtr& msg) + { + cv_bridge::CvImagePtr cv_ptr; + try + { + // sensor_msgs::Image to cv::Mat + cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8); + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // pre-processing + auto tensor_cpu = ToTensor(cv_ptr->image); // OpenCV-image -> Libtorch-tensor + + try { + tensor = tensor_cpu.to(device); // move to device (CPU or GPU) + + tensor = tensor.toType(c10::kFloat); + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor = tensor.unsqueeze(0); + tensor = at::upsample_bilinear2d(tensor, { net_height, net_width }, true); // resize + tensor = tensor.squeeze(0); + tensor = tensor.permute({ 1, 2, 0 }); // CHW -> HWC + + tensor = tensor.div(255).sub(mean).div(std); // normalization + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor.unsqueeze_(0); // CHW -> NCHW + } + catch (const c10::Error& e) + { + std::cerr << " pre-processing exception: " << e.msg() << std::endl; + return; + } + + auto input_to_net = ToInput(tensor); // input to the network + + // inference + output; + try { + output = module.forward(input_to_net).toTensor(); // run inference + } + catch (const c10::Error& e) + { + std::cerr << " module.forward() exception: " << e.msg() << std::endl; + return; + } + + output = output.detach().to(torch::kF32); + + // move to CPU temporary + at::Tensor output_tmp = output; + output_tmp = output_tmp.to(torch::kCPU); + + // normalization + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::min(); + + for (int i = 0; i < net_width * net_height; ++i) { + float val = output_tmp.data_ptr()[i]; + if (min_val > val) min_val = val; + if (max_val < val) max_val = val; + } + float range_val = max_val - min_val; + + output = output.sub(min_val).div(range_val).mul(255.0F).clamp(0, 255).to(torch::kF32); // .to(torch::kU8); + + // resize to the original size if required + if (out_orig_size) { + try { + output = at::upsample_bilinear2d(output.unsqueeze(0), { cv_ptr->image.size().height, cv_ptr->image.size().width }, true); + output = output.squeeze(0); + } + catch (const c10::Error& e) + { + std::cout << " upsample_bilinear2d() exception: " << e.msg() << std::endl; + return; + } + } + output = output.permute({ 1, 2, 0 }).to(torch::kCPU); + + int cv_type = CV_32FC1; // CV_8UC1; + auto cv_img = ToCvImage(output, cv_type); + + sensor_msgs::Image img_msg; + + try { + // cv::Mat -> sensor_msgs::Image + std_msgs::Header header; // empty header + header.seq = 0; // user defined counter + header.stamp = ros::Time::now();// time + //cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::MONO8, cv_img); + cv_bridge::CvImage img_bridge = 
cv_bridge::CvImage(header, sensor_msgs::image_encodings::TYPE_32FC1, cv_img); + + img_bridge.toImageMsg(img_msg); // cv_bridge -> sensor_msgs::Image + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // Output modified video stream + image_pub_.publish(img_msg); + } +}; + +int main(int argc, char** argv) +{ + ros::init(argc, argv, "midas", ros::init_options::AnonymousName); + Midas ic; + ros::spin(); + return 0; +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..a997c4261072d0d627598fe06a723fcc7522d347 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh @@ -0,0 +1,16 @@ +# place any test.mp4 file near with this file + +# roscore +# rosnode kill -a + +source ~/catkin_ws/devel/setup.bash + +roscore & +P1=$! +rosrun midas_cpp talker.py & +P2=$! +rosrun midas_cpp listener_original.py & +P3=$! +rosrun midas_cpp listener.py & +P4=$! +wait $P1 $P2 $P3 $P4 \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5696ef0547af093713ea416d18edd77d11879d0a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py @@ -0,0 +1,277 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import torch +import utils +import cv2 +import argparse +import time + +import numpy as np + +from imutils.video import VideoStream +from midas.model_loader import default_models, load_model + +first_execution = True +def process(device, model, model_type, image, input_size, target_size, optimize, use_camera): + """ + Run the inference and interpolate. + + Args: + device (torch.device): the torch device used + model: the model used for inference + model_type: the type of the model + image: the image fed into the neural network + input_size: the size (width, height) of the neural network input (for OpenVINO) + target_size: the size (width, height) the neural network output is interpolated to + optimize: optimize the model to half-floats on CUDA? + use_camera: is the camera used? + + Returns: + the prediction + """ + global first_execution + + if "openvino" in model_type: + if first_execution or not use_camera: + print(f" Input resized to {input_size[0]}x{input_size[1]} before entering the encoder") + first_execution = False + + sample = [np.reshape(image, (1, 3, *input_size))] + prediction = model(sample)[model.output(0)][0] + prediction = cv2.resize(prediction, dsize=target_size, + interpolation=cv2.INTER_CUBIC) + else: + sample = torch.from_numpy(image).to(device).unsqueeze(0) + + if optimize and device == torch.device("cuda"): + if first_execution: + print(" Optimization to half-floats activated. 
Use with caution, because models like Swin require\n" + " float precision to work properly and may yield non-finite depth values to some extent for\n" + " half-floats.") + sample = sample.to(memory_format=torch.channels_last) + sample = sample.half() + + if first_execution or not use_camera: + height, width = sample.shape[2:] + print(f" Input resized to {width}x{height} before entering the encoder") + first_execution = False + + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=target_size[::-1], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + return prediction + + +def create_side_by_side(image, depth, grayscale): + """ + Take an RGB image and depth map and place them side by side. This includes a proper normalization of the depth map + for better visibility. + + Args: + image: the RGB image + depth: the depth map + grayscale: use a grayscale colormap? + + Returns: + the image and depth map place side by side + """ + depth_min = depth.min() + depth_max = depth.max() + normalized_depth = 255 * (depth - depth_min) / (depth_max - depth_min) + normalized_depth *= 3 + + right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2) / 3 + if not grayscale: + right_side = cv2.applyColorMap(np.uint8(right_side), cv2.COLORMAP_INFERNO) + + if image is None: + return right_side + else: + return np.concatenate((image, right_side), axis=1) + + +def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", optimize=False, side=False, height=None, + square=False, grayscale=False): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + model_type (str): the model type + optimize (bool): optimize the model to half-floats on CUDA? + side (bool): RGB and depth side by side in output images? + height (int): inference encoder image height + square (bool): resize to a square resolution? + grayscale (bool): use a grayscale colormap? + """ + print("Initialize") + + # select device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Device: %s" % device) + + model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square) + + # get input + if input_path is not None: + image_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(image_names) + else: + print("No input path specified. Grabbing images from camera.") + + # create output folder + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + + print("Start processing") + + if input_path is not None: + if output_path is None: + print("Warning: No output path specified. 
Images will be processed but not shown or stored anywhere.") + for index, image_name in enumerate(image_names): + + print(" Processing {} ({}/{})".format(image_name, index + 1, num_images)) + + # input + original_image_rgb = utils.read_image(image_name) # in [0, 1] + image = transform({"image": original_image_rgb})["image"] + + # compute + with torch.no_grad(): + prediction = process(device, model, model_type, image, (net_w, net_h), original_image_rgb.shape[1::-1], + optimize, False) + + # output + if output_path is not None: + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(image_name))[0] + '-' + model_type + ) + if not side: + utils.write_depth(filename, prediction, grayscale, bits=2) + else: + original_image_bgr = np.flip(original_image_rgb, 2) + content = create_side_by_side(original_image_bgr*255, prediction, grayscale) + cv2.imwrite(filename + ".png", content) + utils.write_pfm(filename + ".pfm", prediction.astype(np.float32)) + + else: + with torch.no_grad(): + fps = 1 + video = VideoStream(0).start() + time_start = time.time() + frame_index = 0 + while True: + frame = video.read() + if frame is not None: + original_image_rgb = np.flip(frame, 2) # in [0, 255] (flip required to get RGB) + image = transform({"image": original_image_rgb/255})["image"] + + prediction = process(device, model, model_type, image, (net_w, net_h), + original_image_rgb.shape[1::-1], optimize, True) + + original_image_bgr = np.flip(original_image_rgb, 2) if side else None + content = create_side_by_side(original_image_bgr, prediction, grayscale) + cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', content/255) + + if output_path is not None: + filename = os.path.join(output_path, 'Camera' + '-' + model_type + '_' + str(frame_index)) + cv2.imwrite(filename + ".png", content) + + alpha = 0.1 + if time.time()-time_start > 0: + fps = (1 - alpha) * fps + alpha * 1 / (time.time()-time_start) # exponential moving average + time_start = time.time() + print(f"\rFPS: {round(fps,2)}", end="") + + if cv2.waitKey(1) == 27: # Escape key + break + + frame_index += 1 + print() + + print("Finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default=None, + help='Folder with input images (if no input path is specified, images are tried to be grabbed ' + 'from camera)' + ) + + parser.add_argument('-o', '--output_path', + default=None, + help='Folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default=None, + help='Path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='dpt_beit_large_512', + help='Model type: ' + 'dpt_beit_large_512, dpt_beit_large_384, dpt_beit_base_384, dpt_swin2_large_384, ' + 'dpt_swin2_base_384, dpt_swin2_tiny_256, dpt_swin_large_384, dpt_next_vit_large_384, ' + 'dpt_levit_224, dpt_large_384, dpt_hybrid_384, midas_v21_384, midas_v21_small_256 or ' + 'openvino_midas_v21_small_256' + ) + + parser.add_argument('-s', '--side', + action='store_true', + help='Output images contain RGB and depth images side by side' + ) + + parser.add_argument('--optimize', dest='optimize', action='store_true', help='Use half-float optimization') + parser.set_defaults(optimize=False) + + parser.add_argument('--height', + type=int, default=None, + help='Preferred height of images feed into the encoder during inference. 
Note that the ' + 'preferred height may differ from the actual height, because an alignment to multiples of ' + '32 takes place. Many models support only the height chosen during training, which is ' + 'used automatically if this parameter is not set.' + ) + parser.add_argument('--square', + action='store_true', + help='Option to resize images to a square resolution by changing their widths when images are ' + 'fed into the encoder during inference. If this parameter is not set, the aspect ratio of ' + 'images is tried to be preserved if supported by the model.' + ) + parser.add_argument('--grayscale', + action='store_true', + help='Use a grayscale colormap instead of the inferno one. Although the inferno colormap, ' + 'which is used by default, is better for visibility, it does not allow storing 16-bit ' + 'depth values in PNGs but only 8-bit ones due to the precision limitation of this ' + 'colormap.' + ) + + args = parser.parse_args() + + + if args.model_weights is None: + args.model_weights = default_models[args.model_type] + + # set torch options + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, args.side, args.height, + args.square, args.grayscale) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5b5fe0e63668eab45a55b140826cb3762862b17c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md @@ -0,0 +1,147 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +### TensorFlow inference using `.pb` and `.onnx` models + +1. [Run inference on TensorFlow-model by using TensorFlow](#run-inference-on-tensorflow-model-by-using-tensorFlow) + +2. [Run inference on ONNX-model by using TensorFlow](#run-inference-on-onnx-model-by-using-tensorflow) + +3. [Make ONNX model from downloaded Pytorch model file](#make-onnx-model-from-downloaded-pytorch-model-file) + + +### Run inference on TensorFlow-model by using TensorFlow + +1) Download the model weights [model-f6b98070.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pb) +and [model-small.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.pb) and place the +file in the `/tf/` folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_pb.py + ``` + + Or run the small model: + + ```shell + python tf/run_pb.py --model_weights model-small.pb --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + +### Run inference on ONNX-model by using ONNX-Runtime + +1) Download the model weights [model-f6b98070.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.onnx) +and [model-small.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.onnx) and place the +file in the `/tf/` folder. 
+ +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX Runtime +pip install onnxruntime==1.5.2 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_onnx.py + ``` + + Or run the small model: + + ```shell + python tf/run_onnx.py --model_weights model-small.onnx --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + + +### Make ONNX model from downloaded Pytorch model file + +1) Download the model weights [model-f6b98070.pt](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pt) and place the +file in the root folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install PyTorch TorchVision +pip install -I torch==1.7.0 torchvision==0.8.0 + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX-TensorFlow +git clone https://github.com/onnx/onnx-tensorflow.git +cd onnx-tensorflow +git checkout 095b51b88e35c4001d70f15f80f31014b592b81e +pip install -e . +``` + +#### Usage + +1) Run the converter: + + ```shell + python tf/make_onnx_model.py + ``` + +2) The resulting `model-f6b98070.onnx` file is written to the `/tf/` folder. + + +### Requirements + + The code was tested with Python 3.6.9, PyTorch 1.5.1, TensorFlow 2.2.0, TensorFlow-addons 0.8.3, ONNX 1.7.0, ONNX-TensorFlow (GitHub-master-17.07.2020) and OpenCV 4.3.0. + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2019, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + +### License + +MIT License + + diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d14b0e4e1d2ea70fa315fd7ca7dfd72440a19376 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py @@ -0,0 +1,112 @@ +"""Compute depth maps for images in the input folder. 
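+(Despite the summary above, this script's job is conversion: it temporarily patches ../midas/blocks.py, wraps MidasNet so the ImageNet normalization happens inside forward(), exports the model with torch.onnx.export at opset 9, and then restores blocks.py.)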
+""" +import os +import ntpath +import glob +import torch +import utils +import cv2 +import numpy as np +from torchvision.transforms import Compose, Normalize +from torchvision import transforms + +from shutil import copyfile +import fileinput +import sys +sys.path.append(os.getcwd() + '/..') + +def modify_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename, modify_filename+'.bak') + + with open(modify_filename, 'r') as file : + filedata = file.read() + + filedata = filedata.replace('align_corners=True', 'align_corners=False') + filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models') + filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()') + + with open(modify_filename, 'w') as file: + file.write(filedata) + +def restore_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename+'.bak', modify_filename) + +modify_file() + +from midas.midas_net import MidasNet +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +restore_file() + + +class MidasNet_preprocessing(MidasNet): + """Network for monocular depth estimation. + """ + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + return MidasNet.forward(self, x) + + +def run(model_path): + """Run MonoDepthNN to compute depth maps. + + Args: + model_path (str): path to saved model + """ + print("initialize") + + # select device + + # load network + #model = MidasNet(model_path, non_negative=True) + model = MidasNet_preprocessing(model_path, non_negative=True) + + model.eval() + + print("start processing") + + # input + img_input = np.zeros((3, 384, 384), np.float32) + + # compute + with torch.no_grad(): + sample = torch.from_numpy(img_input).unsqueeze(0) + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img_input.shape[:2], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0]+'.onnx', opset_version=9) + + print("finished") + + +if __name__ == "__main__": + # set paths + # MODEL_PATH = "model.pt" + MODEL_PATH = "../model-f6b98070.pt" + + # compute depth maps + run(MODEL_PATH) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..7107b99969a127f951814f743d5c562a436b2430 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py @@ -0,0 +1,119 @@ +"""Compute depth maps for images in the input folder. 
+""" +import os +import glob +import utils +import cv2 +import sys +import numpy as np +import argparse + +import onnx +import onnxruntime as rt + +from transforms import Resize, NormalizeImage, PrepareForNet + + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # select device + device = "CUDA:0" + #device = "CPU" + print("device: %s" % device) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + print("loading model...") + model = rt.InferenceSession(model_path) + input_name = model.get_inputs()[0].name + output_name = model.get_outputs()[0].name + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0] + prediction = np.array(output).reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.onnx', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py new file mode 100644 index 0000000000000000000000000000000000000000..e46254f7b37f72e7d87672d70fd4b2f393ad7658 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py @@ -0,0 +1,135 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import utils +import cv2 +import argparse + +import tensorflow as tf + +from transforms import Resize, NormalizeImage, PrepareForNet + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. 
+ + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # the runtime initialization will not allocate all memory on the device to avoid out of GPU memory + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + try: + for gpu in gpus: + #tf.config.experimental.set_memory_growth(gpu, True) + tf.config.experimental.set_virtual_device_configuration(gpu, + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)]) + except RuntimeError as e: + print(e) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + graph_def = tf.compat.v1.GraphDef() + with tf.io.gfile.GFile(model_path, 'rb') as f: + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + + model_operations = tf.compat.v1.get_default_graph().get_operations() + input_node = '0:0' + output_layer = model_operations[len(model_operations) - 1].name + ':0' + print("Last layer name: ", output_layer) + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + with tf.compat.v1.Session() as sess: + try: + # load images + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + prob_tensor = sess.graph.get_tensor_by_name(output_layer) + prediction, = sess.run(prob_tensor, {input_node: [img_input] }) + prediction = prediction.reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + except KeyError: + print ("Couldn't find input node: ' + input_node + ' or output layer: " + output_layer + ".") + exit(-1) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.pb', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- 
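The two driver scripts above (`tf/run_onnx.py` and `tf/run_pb.py`) follow the recipe from `tf/README.md`: load the network, resize the input to 384x384 (256x256 for the small model), run it, and resize the prediction back to the original resolution. A condensed sketch of the ONNX path, assuming the `model-f6b98070.onnx` weights from the README and a hypothetical `input/example.jpg`:

```python
# Condensed sketch of the ONNX path implemented by tf/run_onnx.py above; the
# weights file is the one linked in tf/README.md and the image path is hypothetical.
import cv2
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("model-f6b98070.onnx")
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

orig = cv2.cvtColor(cv2.imread("input/example.jpg"), cv2.COLOR_BGR2RGB) / 255.0
net_in = cv2.resize(orig, (384, 384), interpolation=cv2.INTER_CUBIC).astype(np.float32)
blob = np.transpose(net_in, (2, 0, 1))[None, ...]  # HWC in [0, 1] -> NCHW batch;
                                                   # normalization is baked into the exported model

pred = sess.run([output_name], {input_name: blob})[0].reshape(384, 384)
pred = cv2.resize(pred, (orig.shape[1], orig.shape[0]), interpolation=cv2.INTER_CUBIC)
```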
/dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a54bd55f5e31a90fad21242efbfda5a6cc1a7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py @@ -0,0 +1,82 @@ +import numpy as np +import sys +import cv2 + + +def write_pfm(path, image, scale=1): + """Write pfm file. + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + +def read_image(path): + """Read image and output RGB image (0-1). + Args: + path (str): path to file + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = 0 + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3976fd97dfe6a9dc7d4fa144be8fcb0b18b2db --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py @@ -0,0 +1,199 @@ +"""Utils for monoDepth. +""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. 
+ + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, grayscale, bits=1): + """Write depth map to png file. + + Args: + path (str): filepath without extension + depth (array): depth + grayscale (bool): use a grayscale colormap? 
+ """ + if not grayscale: + bits = 1 + + if not np.isfinite(depth).all(): + depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0) + print("WARNING: Non-finite depth values present") + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.dtype) + + if not grayscale: + out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/src/flux/annotator/zoe/zoedepth/models/builder.py b/src/flux/annotator/zoe/zoedepth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. 
Refer above error for details.") from e + try: + get_version = getattr(module, "get_version") + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/src/flux/annotator/zoe/zoedepth/models/depth_model.py b/src/flux/annotator/zoe/zoedepth/models/depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self): + super().__init__() + self.device = 'cpu' + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. 
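+            With the defaults (fh = fw = 3), a 480x640 input is therefore padded by int(sqrt(480/2)*3) = 46 rows and int(sqrt(640/2)*3) = 53 columns on each side.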
+ upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. + padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + @torch.no_grad() + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". 
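+
+        Example (illustrative):
+            >>> depth = model.infer_pil(Image.open("photo.jpg"))  # (H, W) numpy array of metric depth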
+ """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py b/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. 
Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] # n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. 
Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py b/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
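+
+# Distribution layers used by the ZoeDepth heads: LogBinomial turns a per-pixel
+# probability map into a temperature-scaled log-binomial distribution over the depth bins,
+# and ConditionalLogBinomial predicts that probability and temperature from a main
+# feature map plus a conditioning feature map.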
+ +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. + + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] 
+ t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py b/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. 
(for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. + """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] 
+ B_edges[:,1:,...]) + return b, B_centers \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py b/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. 
+ """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? + embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/src/flux/annotator/zoe/zoedepth/models/model_io.py b/src/flux/annotator/zoe/zoedepth/models/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/model_io.py @@ -0,0 +1,92 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") + \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,250 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person 
obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. 
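+            inverse_midas (bool, optional): Whether to invert (and min-max re-normalize) the relative depth map produced by the midas core before it is used as conditioning for the metric head. Defaults to False.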
+ train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). 
Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. 
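+                   e.g. [{'params': <encoder params>, 'lr': lr / encoder_lr_factor}, ..., {'params': <new head params>, 'lr': lr}]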
+ """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..7368ae8031188a9f946d9d3f29633c96e791e68e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,333 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from zoedepth.models.depth_model import DepthModel +from zoedepth.models.base_models.midas import MidasCore +from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed +from zoedepth.models.layers.dist_layers import ConditionalLogBinomial +from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder +from zoedepth.models.model_io import load_state_from_resource + + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. 
+ midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + + """ + + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. 
+ return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. + + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError: + raise ValueError( + f"bin_conf_name {bin_conf_name} not found in bin_confs") + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. 
+ """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..33fbbea3a7d49efe11b005adb5127f441eabfaf6 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/base_trainer.py @@ -0,0 +1,326 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without 
limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os +import uuid +import warnings +from datetime import datetime as dt +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +import wandb +from tqdm import tqdm + +from zoedepth.utils.config import flatten +from zoedepth.utils.misc import RunningAverageDict, colorize, colors + + +def is_rank_zero(args): + return args.rank == 0 + + +class BaseTrainer: + def __init__(self, config, model, train_loader, test_loader=None, device=None): + """ Base Trainer class for training a model.""" + + self.config = config + self.metric_criterion = "abs_rel" + if device is None: + device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + self.device = device + self.model = model + self.train_loader = train_loader + self.test_loader = test_loader + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + + def resize_to_target(self, prediction, target): + if prediction.shape[2:] != target.shape[-2:]: + prediction = nn.functional.interpolate( + prediction, size=target.shape[-2:], mode="bilinear", align_corners=True + ) + return prediction + + def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"): + import glob + import os + + from zoedepth.models.model_io import load_wts + + if hasattr(self.config, "checkpoint"): + checkpoint = self.config.checkpoint + elif hasattr(self.config, "ckpt_pattern"): + pattern = self.config.ckpt_pattern + matches = glob.glob(os.path.join( + checkpoint_dir, f"*{pattern}*{ckpt_type}*")) + if not (len(matches) > 0): + raise ValueError(f"No matches found for the pattern {pattern}") + checkpoint = matches[0] + else: + return + model = load_wts(self.model, checkpoint) + # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it. + print("Loaded weights from {0}".format(checkpoint)) + warnings.warn( + "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.") + self.model = model + + def init_optimizer(self): + m = self.model.module if self.config.multigpu else self.model + + if self.config.same_lr: + print("Using same LR") + if hasattr(m, 'core'): + m.core.unfreeze() + params = self.model.parameters() + else: + print("Using diff LR") + if not hasattr(m, 'get_lr_params'): + raise NotImplementedError( + f"Model {m.__class__.__name__} does not implement get_lr_params. 
Please implement it or use the same LR for all parameters.") + + params = m.get_lr_params(self.config.lr) + + return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd) + + def init_scheduler(self): + lrs = [l['lr'] for l in self.optimizer.param_groups] + return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader), + cycle_momentum=self.config.cycle_momentum, + base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase) + + def train_on_batch(self, batch, train_step): + raise NotImplementedError + + def validate_on_batch(self, batch, val_step): + raise NotImplementedError + + def raise_if_nan(self, losses): + for key, value in losses.items(): + if torch.isnan(value): + raise ValueError(f"{key} is NaN, Stopping training") + + @property + def iters_per_epoch(self): + return len(self.train_loader) + + @property + def total_iters(self): + return self.config.epochs * self.iters_per_epoch + + def should_early_stop(self): + if self.config.get('early_stop', False) and self.step > self.config.early_stop: + return True + + def train(self): + print(f"Training {self.config.name}") + if self.config.uid is None: + self.config.uid = str(uuid.uuid4()).split('-')[-1] + run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}" + self.config.run_id = run_id + self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}" + self.should_write = ((not self.config.distributed) + or self.config.rank == 0) + self.should_log = self.should_write # and logging + if self.should_log: + tags = self.config.tags.split( + ',') if self.config.tags != '' else None + wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root, + tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork")) + + self.model.train() + self.step = 0 + best_loss = np.inf + validate_every = int(self.config.validate_every * self.iters_per_epoch) + + + if self.config.prefetch: + + for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader): + pass + + losses = {} + def stringify_losses(L): return "; ".join(map( + lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items())) + for epoch in range(self.config.epochs): + if self.should_early_stop(): + break + + self.epoch = epoch + ################################# Train loop ########################################################## + if self.should_log: + wandb.log({"Epoch": epoch}, step=self.step) + pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader) + for i, batch in pbar: + if self.should_early_stop(): + print("Early stopping") + break + # print(f"Batch {self.step+1} on rank {self.config.rank}") + losses = self.train_on_batch(batch, i) + # print(f"trained batch {self.step+1} on rank {self.config.rank}") + + self.raise_if_nan(losses) + if is_rank_zero(self.config) and self.config.print_losses: + pbar.set_description( + f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. 
Losses: {stringify_losses(losses)}") + self.scheduler.step() + + if self.should_log and self.step % 50 == 0: + wandb.log({f"Train/{name}": loss.item() + for name, loss in losses.items()}, step=self.step) + + self.step += 1 + + ######################################################################################################## + + if self.test_loader: + if (self.step % validate_every) == 0: + self.model.eval() + if self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_latest.pt") + + ################################# Validation loop ################################################## + # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log( + {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step) + + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + if self.config.distributed: + dist.barrier() + # print(f"Validated: {metrics} on device {self.config.rank}") + + # print(f"Finished step {self.step} on device {self.config.rank}") + ################################################################################################# + + # Save / validate at the end + self.step += 1 # log as final point + self.model.eval() + self.save_checkpoint(f"{self.config.experiment_id}_latest.pt") + if self.test_loader: + + ################################# Validation loop ################################################## + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log({f"Test/{name}": tloss for name, + tloss in test_losses.items()}, step=self.step) + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + def validate(self): + with torch.no_grad(): + losses_avg = RunningAverageDict() + metrics_avg = RunningAverageDict() + for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. 
Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)): + metrics, losses = self.validate_on_batch(batch, val_step=i) + + if losses: + losses_avg.update(losses) + if metrics: + metrics_avg.update(metrics) + + return metrics_avg.get_value(), losses_avg.get_value() + + def save_checkpoint(self, filename): + if not self.should_write: + return + root = self.config.save_dir + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + m = self.model.module if self.config.multigpu else self.model + torch.save( + { + "model": m.state_dict(), + "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size + "epoch": self.epoch + }, fpath) + + def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None): + if not self.should_log: + return + + if min_depth is None: + try: + min_depth = self.config.min_depth + max_depth = self.config.max_depth + except AttributeError: + min_depth = None + max_depth = None + + depth = {k: colorize(v, vmin=min_depth, vmax=max_depth) + for k, v in depth.items()} + scalar_field = {k: colorize( + v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()} + images = {**rgb, **depth, **scalar_field} + wimages = { + prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]} + wandb.log(wimages, step=self.step) + + def log_line_plot(self, data): + if not self.should_log: + return + + plt.plot(data) + plt.ylabel("Scale factors") + wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step) + plt.close() + + def log_bar_plot(self, title, labels, values): + if not self.should_log: + return + + data = [[label, val] for (label, val) in zip(labels, values)] + table = wandb.Table(data=data, columns=["label", "value"]) + wandb.log({title: wandb.plot.bar(table, "label", + "value", title=title)}, step=self.step) diff --git a/src/flux/annotator/zoe/zoedepth/trainers/builder.py b/src/flux/annotator/zoe/zoedepth/trainers/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a663541b08912ebedce21a68c7599ce4c06e85d0 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/builder.py @@ -0,0 +1,48 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +from importlib import import_module + + +def get_trainer(config): + """Builds and returns a trainer based on the config. + + Args: + config (dict): the config dict (typically constructed using utils.config.get_config) + config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module + + Raises: + ValueError: If the specified trainer does not exist under trainers/ folder + + Returns: + Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object + """ + assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format( + config) + try: + Trainer = getattr(import_module( + f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer') + except ModuleNotFoundError as e: + raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e + return Trainer diff --git a/src/flux/annotator/zoe/zoedepth/trainers/loss.py b/src/flux/annotator/zoe/zoedepth/trainers/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a1c15cdf5628c1474c566fdc6e58159d7f5ab --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/loss.py @@ -0,0 +1,316 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.amp as amp +import numpy as np + + +KEY_OUTPUT = 'metric_depth' + + +def extract_key(prediction, key): + if isinstance(prediction, dict): + return prediction[key] + return prediction + + +# Main loss function used for ZoeDepth. 
Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7) +class SILogLoss(nn.Module): + """SILog loss (pixel-wise)""" + def __init__(self, beta=0.15): + super(SILogLoss, self).__init__() + self.name = 'SILog' + self.beta = beta + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + if target.ndim == 3: + target = target.unsqueeze(1) + + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + input = input[mask] + target = target[mask] + + with amp.autocast(enabled=False): # amp causes NaNs in this loss function + alpha = 1e-7 + g = torch.log(input + alpha) - torch.log(target + alpha) + + # n, c, h, w = g.shape + # norm = 1/(h*w) + # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2 + + Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2) + + loss = 10 * torch.sqrt(Dg) + + if torch.isnan(loss): + print("Nan SILog loss") + print("input:", input.shape) + print("target:", target.shape) + print("G", torch.sum(torch.isnan(g))) + print("Input min max", torch.min(input), torch.max(input)) + print("Target min max", torch.min(target), torch.max(target)) + print("Dg", torch.isnan(Dg)) + print("loss", torch.isnan(loss)) + + if not return_interpolated: + return loss + + return loss, intr_input + + +def grad(x): + # x.shape : n, c, h, w + diff_x = x[..., 1:, 1:] - x[..., 1:, :-1] + diff_y = x[..., 1:, 1:] - x[..., :-1, 1:] + mag = diff_x**2 + diff_y**2 + # angle_ratio + angle = torch.atan(diff_y / (diff_x + 1e-10)) + return mag, angle + + +def grad_mask(mask): + return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:] + + +class GradL1Loss(nn.Module): + """Gradient loss""" + def __init__(self): + super(GradL1Loss, self).__init__() + self.name = 'GradL1' + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + grad_gt = grad(target) + grad_pred = grad(input) + mask_g = grad_mask(mask) + + loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g]) + loss = loss + \ + nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g]) + if not return_interpolated: + return loss + return loss, intr_input + + +class OrdinalRegressionLoss(object): + + def __init__(self, ord_num, beta, discretization="SID"): + self.ord_num = ord_num + self.beta = beta + self.discretization = discretization + + def _create_ord_label(self, gt): + N,one, H, W = gt.shape + # print("gt shape:", gt.shape) + + ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device) + if self.discretization == "SID": + label = self.ord_num * torch.log(gt) / np.log(self.beta) + else: + label = self.ord_num * (gt - 1.0) / (self.beta - 1.0) + label = label.long() + mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \ + .view(1, self.ord_num, 1, 1).to(gt.device) + mask = mask.repeat(N, 1, H, W).contiguous().long() + mask = (mask > label) + ord_c0[mask] = 0 + ord_c1 = 1 - ord_c0 + # implementation according to the paper. 
+ # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device) + # ord_label[:, 0::2, :, :] = ord_c0 + # ord_label[:, 1::2, :, :] = ord_c1 + # reimplementation for fast speed. + ord_label = torch.cat((ord_c0, ord_c1), dim=1) + return ord_label, mask + + def __call__(self, prob, gt): + """ + :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor + :param gt: depth ground truth, NXHxW, torch.Tensor + :return: loss: loss value, torch.float + """ + # N, C, H, W = prob.shape + valid_mask = gt > 0. + ord_label, mask = self._create_ord_label(gt) + # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape)) + entropy = -prob * ord_label + loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)] + return loss.mean() + + +class DiscreteNLLLoss(nn.Module): + """Cross entropy loss""" + def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64): + super(DiscreteNLLLoss, self).__init__() + self.name = 'CrossEntropy' + self.ignore_index = -(depth_bins + 1) + # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index) + self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + self.min_depth = min_depth + self.max_depth = max_depth + self.depth_bins = depth_bins + self.alpha = 1 + self.zeta = 1 - min_depth + self.beta = max_depth + self.zeta + + def quantize_depth(self, depth): + # depth : N1HW + # output : NCHW + + # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins + depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha) + depth = depth * (self.depth_bins - 1) + depth = torch.round(depth) + depth = depth.long() + return depth + + + + def _dequantize_depth(self, depth): + """ + Inverse of quantization + depth : NCHW -> N1HW + """ + # Get the center of the bin + + + + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + # assert torch.all(input <= 0), "Input should be negative" + + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + # assert torch.all(input)<=1) + if target.ndim == 3: + target = target.unsqueeze(1) + + target = self.quantize_depth(target) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + # Set the mask to ignore_index + mask = mask.long() + input = input * mask + (1 - mask) * self.ignore_index + target = target * mask + (1 - mask) * self.ignore_index + + + + input = input.flatten(2) # N, nbins, H*W + target = target.flatten(1) # N, H*W + loss = self._loss_func(input, target) + + if not return_interpolated: + return loss + return loss, intr_input + + + + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + # A needs to be a positive definite matrix. 
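+    # For this 2x2 normal-equation system, A is positive semi-definite by construction
+    # (a_00 = sum(mask * prediction^2) >= 0 and a_11 = sum(mask) >= 0), so checking
+    # det > 0 per sample is enough to guarantee an invertible, positive definite A.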
+ valid = det > 0 + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 +class ScaleAndShiftInvariantLoss(nn.Module): + def __init__(self): + super().__init__() + self.name = "SSILoss" + + def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False): + + if prediction.shape[-1] != target.shape[-1] and interpolate: + prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = prediction + else: + intr_input = prediction + + + prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze() + assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}." + + scale, shift = compute_scale_and_shift(prediction, target, mask) + + scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) + + loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask]) + if not return_interpolated: + return loss + return loss, intr_input + + + + +if __name__ == '__main__': + # Tests for DiscreteNLLLoss + celoss = DiscreteNLLLoss() + print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, )) + + d = torch.Tensor([6.59, 3.8, 10.0]) + print(celoss.dequantize_depth(celoss.quantize_depth(d))) diff --git a/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d528ae126f1c51b2f25fd31f94a39591ceb2f43a --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics + +from .base_trainer import BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.domain_classifier_loss = nn.CrossEntropyLoss() + + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + + Assumes all images in a batch are from the same dataset + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + # batch['dataset'] is a tensor strings all valued either 'nyu' or 'kitti'. labels nyu -> 0, kitti -> 1 + dataset = batch['dataset'][0] + # Convert to 0s or 1s + domain_labels = torch.Tensor([dataset == 'kitti' for _ in range( + images.size(0))]).to(torch.long).to(self.device) + + # m = self.model.module if self.config.multigpu else self.model + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + output = self.model(images) + pred_depths = output['metric_depth'] + domain_logits = output['domain_logits'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + if self.config.w_domain > 0: + l_domain = self.domain_classifier_loss( + domain_logits, domain_labels) + loss = loss + self.config.w_domain * l_domain + losses["DomainLoss"] = l_domain + else: + l_domain = torch.Tensor([0.]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + return losses + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(images)["metric_depth"] + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + mask = 
torch.logical_and( + depths_gt > self.config.min_depth, depths_gt < self.config.max_depth) + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac1c24c0512c1c1b191670a7c24abb4fca47ba1 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py @@ -0,0 +1,177 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics +from zoedepth.data.preprocess import get_black_border + +from .base_trainer import BaseTrainer +from torchvision import transforms +from PIL import Image +import numpy as np + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + dataset = batch['dataset'][0] + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + + output = self.model(images) + pred_depths = output['metric_depth'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + # -99 is treated as invalid depth in the log_images function and is colored grey. 
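+                # (colorize() in zoedepth.utils.misc defaults to invalid_val=-99, so these
+                # masked-out pixels are rendered with the grey background_color when logged.)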
+ depths_gt[torch.logical_not(mask)] = -99 + + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + if self.config.get("log_rel", False): + self.log_images( + scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel") + + self.scaler.update() + self.optimizer.zero_grad() + + return losses + + @torch.no_grad() + def eval_infer(self, x): + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(x)['metric_depth'] + return pred_depths + + @torch.no_grad() + def crop_aware_infer(self, x): + # if we are not avoiding the black border, we can just use the normal inference + if not self.config.get("avoid_boundary", False): + return self.eval_infer(x) + + # otherwise, we need to crop the image to avoid the black border + # For now, this may be a bit slow due to converting to numpy and back + # We assume no normalization is done on the input image + + # get the black border + assert x.shape[0] == 1, "Only batch size 1 is supported for now" + x_pil = transforms.ToPILImage()(x[0].cpu()) + x_np = np.array(x_pil, dtype=np.uint8) + black_border_params = get_black_border(x_np) + top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right + x_np_cropped = x_np[top:bottom, left:right, :] + x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped)) + + # run inference on the cropped image + pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device)) + + # resize the prediction to x_np_cropped's size + pred_depths_cropped = nn.functional.interpolate( + pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False) + + + # pad the prediction back to the original size + pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype) + pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped + + return pred_depths + + + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + mask = batch["mask"].to(self.device) + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + mask = mask.squeeze().unsqueeze(0).unsqueeze(0) + if dataset == 'nyu': + pred_depths = self.crop_aware_infer(images) + else: + pred_depths = self.eval_infer(images) + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/src/flux/annotator/zoe/zoedepth/utils/__init__.py b/src/flux/annotator/zoe/zoedepth/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py b/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/src/flux/annotator/zoe/zoedepth/utils/config.py b/src/flux/annotator/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + 
"sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. 
+ Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. 
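+
+    Example (illustrative; the overwrite keys shown here are assumptions, any valid
+    config key can be passed the same way):
+        >>> conf = get_config("zoedepth_nk", mode="infer")
+        >>> conf = get_config("zoedepth", mode="train", dataset="nyu", lr=1e-4)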
+ + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. 
Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py b/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... 
+ AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/src/flux/annotator/zoe/zoedepth/utils/geometry.py b/src/flux/annotator/zoe/zoedepth/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/geometry.py @@ -0,0 +1,98 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np + +def get_intrinsics(H,W): + """ + Intrinsics for a pinhole camera model. + Assume fov of 55 degrees and central principal point. 
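+
+    Worked value (follows directly from the 55 degree assumption above):
+    f = W / (2 * tan(55/2 deg)), so for W = 640 this gives f ~= 614.7 px,
+    with principal point cx = W/2, cy = H/2.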
+ """ + f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0) + cx = 0.5 * W + cy = 0.5 * H + return np.array([[f, 0, cx], + [0, f, cy], + [0, 0, 1]]) + +def depth_to_points(depth, R=None, t=None): + + K = get_intrinsics(depth.shape[1], depth.shape[2]) + Kinv = np.linalg.inv(K) + if R is None: + R = np.eye(3) + if t is None: + t = np.zeros(3) + + # M converts from your coordinate to PyTorch3D's coordinate system + M = np.eye(3) + M[0, 0] = -1.0 + M[1, 1] = -1.0 + + height, width = depth.shape[1:3] + + x = np.arange(width) + y = np.arange(height) + coord = np.stack(np.meshgrid(x, y), -1) + coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1 + coord = coord.astype(np.float32) + # coord = torch.as_tensor(coord, dtype=torch.float32, device=device) + coord = coord[None] # bs, h, w, 3 + + D = depth[:, :, :, None, None] + # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape ) + pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None] + # pts3D_1 live in your coordinate system. Convert them to Py3D's + pts3D_1 = M[None, None, None, ...] @ pts3D_1 + # from reference to targe tviewpoint + pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None] + # pts3D_2 = pts3D_1 + # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w + return pts3D_2[:, :, :, :3, 0][0] + + +def create_triangles(h, w, mask=None): + """ + Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68 + Creates mesh triangle indices from a given pixel grid size. + This function is not and need not be differentiable as triangle indices are + fixed. + Args: + h: (int) denoting the height of the image. + w: (int) denoting the width of the image. + Returns: + triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3) + """ + x, y = np.meshgrid(range(w - 1), range(h - 1)) + tl = y * w + x + tr = y * w + x + 1 + bl = (y + 1) * w + x + br = (y + 1) * w + x + 1 + triangles = np.array([tl, bl, tr, br, tr, bl]) + triangles = np.transpose(triangles, (1, 2, 0)).reshape( + ((w - 1) * (h - 1) * 2, 3)) + if mask is not None: + mask = mask.reshape(-1) + triangles = triangles[mask[triangles].all(1)] + return triangles diff --git a/src/flux/annotator/zoe/zoedepth/utils/misc.py b/src/flux/annotator/zoe/zoedepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33 --- /dev/null +++ b/src/flux/annotator/zoe/zoedepth/utils/misc.py @@ -0,0 +1,368 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +"""Miscellaneous utility functions.""" + +from scipy import ndimage + +import base64 +import math +import re +from io import BytesIO + +import matplotlib +import matplotlib.cm +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.nn +import torch.nn as nn +import torch.utils.data.distributed +from PIL import Image +from torchvision.transforms import ToTensor + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + + +class RunningAverageDict: + """A dictionary of running averages.""" + def __init__(self): + self._dict = None + + def update(self, new_dict): + if new_dict is None: + return + + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + if self._dict is None: + return None + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): + """Converts a depth map to a color image. + + Args: + value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed + vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. + vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. + cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. + invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. + invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. + background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). + gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. + value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. + + Returns: + numpy.ndarray, dtype - uint8: Colored depth map. 
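A hedged usage sketch for the running-average helpers above (both classes assumed in scope): they keep only the current mean and a count, so evaluation metrics can be accumulated batch by batch without storing the whole history.

ra = RunningAverage()
for v in (1.0, 2.0, 3.0):
    ra.append(v)
print(ra.get_value())              # 2.0

rad = RunningAverageDict()
rad.update({"abs_rel": 0.1, "rmse": 1.0})
rad.update({"abs_rel": 0.3, "rmse": 3.0})
print(rad.get_value()["rmse"])     # 2.0, the mean of the two updates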
Shape: (H, W, 4) + """ + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + + value = value.squeeze() + if invalid_mask is None: + invalid_mask = value == invalid_val + mask = np.logical_not(invalid_mask) + + # normalize + vmin = np.percentile(value[mask],2) if vmin is None else vmin + vmax = np.percentile(value[mask],85) if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + + # squeeze last dim if it exists + # grey out the invalid values + + value[invalid_mask] = np.nan + cmapper = matplotlib.cm.get_cmap(cmap) + if value_transform: + value = value_transform(value) + # value = value / value.max() + value = cmapper(value, bytes=True) # (nxmx4) + + # img = value[:, :, :] + img = value[...] + img[invalid_mask] = background_color + + # return img.transpose((2, 0, 1)) + if gamma_corrected: + # gamma correction + img = img / 255 + img = np.power(img, 2.2) + img = img * 255 + img = img.astype(np.uint8) + return img + + +def count_parameters(model, include_all=False): + return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all) + + +def compute_errors(gt, pred): + """Compute metrics for 'pred' compared to 'gt' + + Args: + gt (numpy.ndarray): Ground truth values + pred (numpy.ndarray): Predicted values + + gt.shape should be equal to pred.shape + + Returns: + dict: Dictionary containing the following metrics: + 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25 + 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2 + 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3 + 'abs_rel': Absolute relative error + 'rmse': Root mean squared error + 'log_10': Absolute log10 error + 'sq_rel': Squared relative error + 'rmse_log': Root mean squared error on the log scale + 'silog': Scale invariant log error + """ + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs): + """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics. 
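A hedged sketch of compute_errors above on tiny synthetic arrays (the function itself is assumed to be in scope); in real use it receives the valid-masked gt/pred arrays prepared by compute_metrics.

import numpy as np

gt = np.array([1.0, 2.0, 4.0])
pred = np.array([1.1, 1.8, 4.4])
m = compute_errors(gt, pred)
print(m["a1"])                    # 1.0: every ratio max(gt/pred, pred/gt) stays below 1.25
print(round(m["abs_rel"], 3))     # 0.1: mean of |gt - pred| / gt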
+ """ + if 'config' in kwargs: + config = kwargs['config'] + garg_crop = config.garg_crop + eigen_crop = config.eigen_crop + min_depth_eval = config.min_depth_eval + max_depth_eval = config.max_depth_eval + + if gt.shape[-2:] != pred.shape[-2:] and interpolate: + pred = nn.functional.interpolate( + pred, gt.shape[-2:], mode='bilinear', align_corners=True) + + pred = pred.squeeze().cpu().numpy() + pred[pred < min_depth_eval] = min_depth_eval + pred[pred > max_depth_eval] = max_depth_eval + pred[np.isinf(pred)] = max_depth_eval + pred[np.isnan(pred)] = min_depth_eval + + gt_depth = gt.squeeze().cpu().numpy() + valid_mask = np.logical_and( + gt_depth > min_depth_eval, gt_depth < max_depth_eval) + + if garg_crop or eigen_crop: + gt_height, gt_width = gt_depth.shape + eval_mask = np.zeros(valid_mask.shape) + + if garg_crop: + eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), + int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 + + elif eigen_crop: + # print("-"*10, " EIGEN CROP ", "-"*10) + if dataset == 'kitti': + eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), + int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 + else: + # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images" + eval_mask[45:471, 41:601] = 1 + else: + eval_mask = np.ones(valid_mask.shape) + valid_mask = np.logical_and(valid_mask, eval_mask) + return compute_errors(gt_depth[valid_mask], pred[valid_mask]) + + +#################################### Model uilts ################################################ + + +def parallelize(config, model, find_unused_parameters=True): + + if config.gpu is not None: + torch.cuda.set_device(config.gpu) + model = model.cuda(config.gpu) + + config.multigpu = False + if config.distributed: + # Use DDP + config.multigpu = True + config.rank = config.rank * config.ngpus_per_node + config.gpu + dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, + world_size=config.world_size, rank=config.rank) + config.batch_size = int(config.batch_size / config.ngpus_per_node) + # config.batch_size = 8 + config.workers = int( + (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node) + print("Device", config.gpu, "Rank", config.rank, "batch size", + config.batch_size, "Workers", config.workers) + torch.cuda.set_device(config.gpu) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.cuda(config.gpu) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu, + find_unused_parameters=find_unused_parameters) + + elif config.gpu is None: + # Use DP + config.multigpu = True + model = model.cuda() + model = torch.nn.DataParallel(model) + + return model + + +################################################################################################# + + +##################################################################################################### + + +class colors: + '''Colors class: + Reset all colors with colors.reset + Two subclasses fg for foreground and bg for background. + Use as colors.subclass.colorname. + i.e. colors.fg.red or colors.bg.green + Also, the generic bold, disable, underline, reverse, strikethrough, + and invisible work with the main class + i.e. 
colors.bold + ''' + reset = '\033[0m' + bold = '\033[01m' + disable = '\033[02m' + underline = '\033[04m' + reverse = '\033[07m' + strikethrough = '\033[09m' + invisible = '\033[08m' + + class fg: + black = '\033[30m' + red = '\033[31m' + green = '\033[32m' + orange = '\033[33m' + blue = '\033[34m' + purple = '\033[35m' + cyan = '\033[36m' + lightgrey = '\033[37m' + darkgrey = '\033[90m' + lightred = '\033[91m' + lightgreen = '\033[92m' + yellow = '\033[93m' + lightblue = '\033[94m' + pink = '\033[95m' + lightcyan = '\033[96m' + + class bg: + black = '\033[40m' + red = '\033[41m' + green = '\033[42m' + orange = '\033[43m' + blue = '\033[44m' + purple = '\033[45m' + cyan = '\033[46m' + lightgrey = '\033[47m' + + +def printc(text, color): + print(f"{color}{text}{colors.reset}") + +############################################ + +def get_image_from_url(url): + response = requests.get(url) + img = Image.open(BytesIO(response.content)).convert("RGB") + return img + +def url_to_torch(url, size=(384, 384)): + img = get_image_from_url(url) + img = img.resize(size, Image.ANTIALIAS) + img = torch.from_numpy(np.asarray(img)).float() + img = img.permute(2, 0, 1) + img.div_(255) + return img + +def pil_to_batched_tensor(img): + return ToTensor()(img).unsqueeze(0) + +def save_raw_16bit(depth, fpath="raw.png"): + if isinstance(depth, torch.Tensor): + depth = depth.squeeze().cpu().numpy() + + assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array" + assert depth.ndim == 2, "Depth must be 2D" + depth = depth * 256 # scale for 16-bit png + depth = depth.astype(np.uint16) + depth = Image.fromarray(depth) + depth.save(fpath) + print("Saved raw depth to", fpath) \ No newline at end of file diff --git a/src/flux/api.py b/src/flux/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b08202adb35d2ffae320bb9b47f567e538837836 --- /dev/null +++ b/src/flux/api.py @@ -0,0 +1,194 @@ +import io +import os +import time +from pathlib import Path + +import requests +from PIL import Image + +API_ENDPOINT = "https://api.bfl.ml" + + +class ApiException(Exception): + def __init__(self, status_code: int, detail: str | list[dict] | None = None): + super().__init__() + self.detail = detail + self.status_code = status_code + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + if self.detail is None: + message = None + elif isinstance(self.detail, str): + message = self.detail + else: + message = "[" + ",".join(d["msg"] for d in self.detail) + "]" + return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" + + +class ImageRequest: + def __init__( + self, + prompt: str, + width: int = 1024, + height: int = 1024, + name: str = "flux.1-pro", + num_steps: int = 50, + prompt_upsampling: bool = False, + seed: int | None = None, + validate: bool = True, + launch: bool = True, + api_key: str | None = None, + ): + """ + Manages an image generation request to the API. 
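A hedged round-trip sketch for save_raw_16bit from the misc utilities above (helper assumed in scope): depth is scaled by 256 and stored as uint16, so loading the PNG and dividing by 256 recovers the values up to 1/256 quantisation.

import numpy as np
from PIL import Image

depth = np.random.rand(480, 640).astype(np.float32) * 10.0    # fake metric depth in [0, 10)
save_raw_16bit(depth, "raw.png")                               # helper defined above
restored = np.asarray(Image.open("raw.png")).astype(np.float32) / 256.0
print(float(np.abs(restored - depth).max()) < 1 / 256 + 1e-6)  # True: only quantisation error remains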
+ + Args: + prompt: Prompt to sample + width: Width of the image in pixel + height: Height of the image in pixel + name: Name of the model + num_steps: Number of network evaluations + prompt_upsampling: Use prompt upsampling + seed: Fix the generation seed + validate: Run input validation + launch: Directly launches request + api_key: Your API key if not provided by the environment + + Raises: + ValueError: For invalid input + ApiException: For errors raised from the API + """ + if validate: + if name not in ["flux.1-pro"]: + raise ValueError(f"Invalid model {name}") + elif width % 32 != 0: + raise ValueError(f"width must be divisible by 32, got {width}") + elif not (256 <= width <= 1440): + raise ValueError(f"width must be between 256 and 1440, got {width}") + elif height % 32 != 0: + raise ValueError(f"height must be divisible by 32, got {height}") + elif not (256 <= height <= 1440): + raise ValueError(f"height must be between 256 and 1440, got {height}") + elif not (1 <= num_steps <= 50): + raise ValueError(f"steps must be between 1 and 50, got {num_steps}") + + self.request_json = { + "prompt": prompt, + "width": width, + "height": height, + "variant": name, + "steps": num_steps, + "prompt_upsampling": prompt_upsampling, + } + if seed is not None: + self.request_json["seed"] = seed + + self.request_id: str | None = None + self.result: dict | None = None + self._image_bytes: bytes | None = None + self._url: str | None = None + if api_key is None: + self.api_key = os.environ.get("BFL_API_KEY") + else: + self.api_key = api_key + + if launch: + self.request() + + def request(self): + """ + Request to generate the image. + """ + if self.request_id is not None: + return + response = requests.post( + f"{API_ENDPOINT}/v1/image", + headers={ + "accept": "application/json", + "x-key": self.api_key, + "Content-Type": "application/json", + }, + json=self.request_json, + ) + result = response.json() + if response.status_code != 200: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + self.request_id = response.json()["id"] + + def retrieve(self) -> dict: + """ + Wait for the generation to finish and retrieve response. + """ + if self.request_id is None: + self.request() + while self.result is None: + response = requests.get( + f"{API_ENDPOINT}/v1/get_result", + headers={ + "accept": "application/json", + "x-key": self.api_key, + }, + params={ + "id": self.request_id, + }, + ) + result = response.json() + if "status" not in result: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + elif result["status"] == "Ready": + self.result = result["result"] + elif result["status"] == "Pending": + time.sleep(0.5) + else: + raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") + return self.result + + @property + def bytes(self) -> bytes: + """ + Generated image as bytes. 
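A hedged usage sketch of the ImageRequest wrapper above (class assumed importable): BFL_API_KEY must be set in the environment or passed via api_key; with the default launch=True the request is posted on construction, and url/save poll retrieve() until the sample is ready.

request = ImageRequest("a misty forest at dawn", width=1024, height=768, num_steps=40)
print(request.url)              # public URL of the finished sample
request.save("outputs/forest")  # file suffix is taken from the URL and appended automatically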
+ """ + if self._image_bytes is None: + response = requests.get(self.url) + if response.status_code == 200: + self._image_bytes = response.content + else: + raise ApiException(status_code=response.status_code) + return self._image_bytes + + @property + def url(self) -> str: + """ + Public url to retrieve the image from + """ + if self._url is None: + result = self.retrieve() + self._url = result["sample"] + return self._url + + @property + def image(self) -> Image.Image: + """ + Load the image as a PIL Image + """ + return Image.open(io.BytesIO(self.bytes)) + + def save(self, path: str): + """ + Save the generated image to a local path + """ + suffix = Path(self.url).suffix + if not path.endswith(suffix): + path = path + suffix + Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as file: + file.write(self.bytes) + + +if __name__ == "__main__": + from fire import Fire + + Fire(ImageRequest) diff --git a/src/flux/cli.py b/src/flux/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..f3624bc6c387f359162e68f46995b12ce341970a --- /dev/null +++ b/src/flux/cli.py @@ -0,0 +1,254 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from einops import rearrange +from fire import Fire +from PIL import ExifTags, Image + +from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from flux.util import (configs, embed_watermark, load_ae, load_clip, + load_flow_model, load_t5) +from transformers import pipeline + +NSFW_THRESHOLD = 0.85 + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command 
'{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting seed to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +@torch.inference_mode() +def main( + name: str = "flux-schnell", + width: int = 1360, + height: int = 768, + seed: int | None = None, + prompt: str = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ), + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 3.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection") + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + # allow for packing and conversion to latent space + height = 16 * (height // 16) + width = 16 * (width // 16) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if loop: + opts = parse_prompt(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare(t5, clip, x, prompt=opts.prompt) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name 
!= "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + t1 = time.perf_counter() + + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + + if nsfw_score < NSFW_THRESHOLD: + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + idx += 1 + else: + print("Your generated image may contain NSFW content.") + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/src/flux/controlnet.py b/src/flux/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a04cc0234b2b726a550cbe62d027943f6bbcbb --- /dev/null +++ b/src/flux/controlnet.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. 
+ """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + return controlnet_block_res_samples diff --git a/src/flux/math.py b/src/flux/math.py new file mode 100644 index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12 --- /dev/null +++ b/src/flux/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = 
torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/src/flux/model.py b/src/flux/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e811491dd9861d92e00c5859ba7420ab20394e79 --- /dev/null +++ b/src/flux/model.py @@ -0,0 +1,329 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = True # False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + 
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, # clip + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + image_proj: Tensor | None = None, + ip_scale: Tensor | float = 1.0, + use_share_weight_referencenet=False, + single_img_ids: Tensor | None = None, + single_block_refnet=False, + double_block_refnet=False, + ) -> Tensor: + if single_block_refnet or double_block_refnet: + assert use_share_weight_referencenet == True + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + # print("vec shape 1:", vec.shape) + # print("y shape 1:", y.shape) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + # print("vec shape 1.5:", vec.shape) + vec = vec + self.vector_in(y) + # print("vec shape 2:", vec.shape) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if use_share_weight_referencenet: + # print("In img shape:", img.shape) + img_latent_length = img.shape[1] + single_ids = torch.cat((txt_ids, single_img_ids), dim=1) + single_pe = self.pe_embedder(single_ids) + if double_block_refnet and (not single_block_refnet): + double_block_pe = pe + double_block_img = img + single_block_pe = single_pe + + elif single_block_refnet and (not double_block_refnet): + double_block_pe = single_pe + double_block_img = img[:, img_latent_length//2:, :] + single_block_pe = pe + ref_img_latent = img[:, :img_latent_length//2, :] + else: + print("RefNet only support either double blocks or single blocks. 
If you want to turn on all blocks for RefNet, please use Spatial Condition.") + raise NotImplementedError + + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + if not use_share_weight_referencenet: + img, txt = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + image_proj, + ip_scale, + use_reentrant=True, + ) + else: + double_block_img, txt = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + double_block_img, + txt, + vec, + double_block_pe, + image_proj, + ip_scale, + use_reentrant=True, + ) + else: + if not use_share_weight_referencenet: + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + image_proj=image_proj, + ip_scale=ip_scale, + ) + else: + double_block_img, txt = block( + img=double_block_img, + txt=txt, + vec=vec, + pe=double_block_pe, + image_proj=image_proj, + ip_scale=ip_scale, + ) + # controlnet residual + if block_controlnet_hidden_states is not None: + if not use_share_weight_referencenet: + img = img + block_controlnet_hidden_states[index_block % 2] + else: + double_block_img = double_block_img + block_controlnet_hidden_states[index_block % 2] + + if use_share_weight_referencenet: + mid_img = double_block_img + # print("After double blocks img shape:",mid_img.shape) + if double_block_refnet and (not single_block_refnet): + single_block_img = mid_img[:, img_latent_length//2:, :] + elif single_block_refnet and (not double_block_refnet): + single_block_img = torch.cat([ref_img_latent, mid_img], dim=1) + single_block_img = torch.cat((txt, single_block_img), 1) + else: + img = torch.cat((txt, img), 1) + # print("single block input img shape:", single_block_img.shape) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + if not use_share_weight_referencenet: + img = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + use_reentrant=True, + ) + else: + single_block_img = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + single_block_img, + vec, + single_block_pe, + use_reentrant=True, + ) + else: + if not use_share_weight_referencenet: + img = block( + img, + vec=vec, + pe=pe, + ) + else: + single_block_img = block( + single_block_img, + vec=vec, + pe=single_block_pe, + ) + if use_share_weight_referencenet: + out_img = single_block_img + if double_block_refnet and (not single_block_refnet): + out_img = out_img[:, txt.shape[1]:, ...] + elif single_block_refnet and (not double_block_refnet): + out_img = out_img[:, txt.shape[1]:, ...] + out_img = out_img[:, img_latent_length//2:, :] + img = out_img + # print("output img shape:", img.shape) + else: + img = img[:, txt.shape[1] :, ...] 
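A hedged, standalone illustration (with illustrative sizes) of the token bookkeeping above: text tokens are prepended before the single-stream blocks and sliced back off afterwards, so only image tokens reach final_layer.

import torch

txt_demo = torch.randn(1, 512, 3072)    # text tokens; the 3072 hidden size is an assumption
img_demo = torch.randn(1, 1024, 3072)   # packed image tokens
joint = torch.cat((txt_demo, img_demo), 1)
print(joint.shape)                               # torch.Size([1, 1536, 3072])
print(joint[:, txt_demo.shape[1]:, ...].shape)   # torch.Size([1, 1024, 3072])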
+ + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + + +# In img shape: torch.Size([1, 2048, 3072]) +# After double blocks img shape: torch.Size([1, 1024, 3072]) +# single block input img shape: torch.Size([1, 2560, 3072]) +# output img shape: torch.Size([1, 1024, 3072]) +# +# In img shape: torch.Size([1, 2048, 3072]) +# After double blocks img shape: torch.Size([1, 2048, 3072]) [78/1966] +# single block input img shape: torch.Size([1, 1536, 3072]) +# output img shape: torch.Size([1, 1024, 3072]) \ No newline at end of file diff --git a/src/flux/modules/autoencoder.py b/src/flux/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..75159f711f65f064107a1a1b9be6f09fc9872028 --- /dev/null +++ b/src/flux/modules/autoencoder.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, 
mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in 
reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/src/flux/modules/conditioner.py b/src/flux/modules/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..207a1693b2dd9e57b534b43e35f19ee100573aaa --- /dev/null +++ b/src/flux/modules/conditioner.py @@ -0,0 +1,39 @@ +from torch import Tensor, nn +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5Tokenizer) + + +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + # self.is_clip = version.startswith("openai") + self.is_clip = "openai" in version + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, 
max_length=max_length) + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/src/flux/modules/layers.py b/src/flux/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e745832b842c311bc19c951ccb410039a91cfa --- /dev/null +++ b/src/flux/modules/layers.py @@ -0,0 +1,625 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from ..math import attention, rope +import torch.nn.functional as F + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
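A hedged usage sketch of HFEmbedder above (class assumed in scope); the checkpoint names are illustrative assumptions rather than anything pinned by this file.

clip_embedder = HFEmbedder("openai/clip-vit-large-patch14", max_length=77)
t5_embedder = HFEmbedder("google/t5-v1_1-xxl", max_length=512)
vec = clip_embedder(["a photo of a forest"])   # pooled CLIP output, shape (1, 768)
txt = t5_embedder(["a photo of a forest"])     # per-token T5 states, shape (1, 512, 4096)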
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + +class FLuxSelfAttnProcessor: + def __call__(self, attn, x, pe, **attention_kwargs): + print('2' * 30) + + qkv = attn.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + return x + +class LoraFluxAttnProcessor(nn.Module): + + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + self.proj_lora(x) * self.lora_weight + print('1' * 30) + print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm') + return x + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + def forward(self): + # attention is dispatched through the block processors, which call qkv/norm/proj directly + pass + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = 
attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img blocks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class IPDoubleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with double stream block.""" + + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) + self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) + + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) + + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias) + + def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs): + + # Prepare image for attention + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:] + + # print(f"txt_attn shape: {txt_attn.size()}") + # print(f"img_attn shape: {img_attn.size()}") + + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + 
txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + + + # IP-adapter processing + ip_query = img_q # latent sample query + ip_key = self.ip_adapter_double_stream_k_proj(image_proj) + ip_value = self.ip_adapter_double_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value, + dropout_p=0.0, + is_causal=False + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim) + + img = img + ip_scale * ip_attention + + return img, txt + +# TODO: ReferenceNet Block Processor +class DoubleStreamBlockProcessor: + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + # img: latent; txt: text embedding; vec: timestep embedding; pe: position embedding; + # prepare image for attention + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + # if use_share_weight_referencenet: + # real_image_batch_size = 1 + # ref_img_modulated = img_modulated[:1, ...] + # print("ref latent shape:", ref_img_modulated.shape) + # noisy_img_modulated = img_modulated[1:, ...] + # print("noise latent shape:", noisy_img_modulated.shape) + # img_modulated = torch.cat([noisy_img_modulated, ref_img_modulated], dim=1) + # print("input latent shape:", img_modulated.shape) + img_qkv = attn.img_attn.qkv(img_modulated) + # print("img_qkv shape:", img_qkv.shape) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", B=1, K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + + txt_qkv = attn.txt_attn.qkv(txt_modulated) + # print("txt_qkv shape:", txt_qkv.shape) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", B=1, K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + # print("txt_q shape:", txt_q.shape) + # print("img_q shape:", img_q.shape) + # print("txt_k shape:", txt_k.shape) + # print("img_k shape:", img_k.shape) + # print("txt_v shape:", txt_v.shape) + # print("img_v shape:", img_v.shape) + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + # print(" txt_attn shape:", txt_attn.shape) + # print(" img_attn shape:", img_attn.shape) + # print(" img shape:", img.shape) + # calculate the img blocks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + # raise NotImplementedError + # img: torch.Size([1, 1024, 3072]) + # txt: torch.Size([1, 512, 3072]) + # txt_q shape: torch.Size([1, 24, 512, 128]) + # img_q 
shape: torch.Size([1, 24, 1024, 128]) + # txt_attn shape: torch.Size([1, 512, 3072]) + # img_attn shape: torch.Size([1, 1024, 3072]) + # img shape: torch.Size([1, 1024, 3072]) + return img, txt + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_dim = hidden_size // num_heads + + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + processor = DoubleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor = None, + ip_scale: float =1.0, + use_share_weight_referencenet=False, + ) -> tuple[Tensor, Tensor]: + if image_proj is None and use_share_weight_referencenet: + return self.processor(self, img, txt, vec, pe, use_share_weight_referencenet) + elif image_proj is None: + return self.processor(self, img, txt, vec, pe) + else: + return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) + +class IPSingleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with single stream block.""" + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." 
+ ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False) + self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False) + + nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight) + + def __call__( + self, + attn: nn.Module, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # IP-adapter processing + ip_query = q + ip_key = self.ip_adapter_single_stream_k_proj(image_proj) + ip_value = self.ip_adapter_single_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)") + + attn_out = attn_1 + ip_scale * ip_attention + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2)) + out = x + mod.gate * output + + return out + + +class SingleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight + output = x + mod.gate * output + return output + +# TODO: ReferenceNet Block Processor +class SingleStreamBlockProcessor: + def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp 
stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = x + mod.gate * output + return output + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(self.head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + processor = SingleStreamBlockProcessor() + self.set_processor(processor) + + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + if image_proj is None: + return self.processor(self, x, vec, pe) + else: + return self.processor(self, x, vec, pe, image_proj, ip_scale) + + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + +class ImageProjModel(torch.nn.Module): + """Projection Model + https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28 + """ + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + diff --git a/src/flux/sampling.py b/src/flux/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..875927d025ddc5618e93a4c85992aad6ffb3846d --- /dev/null +++ b/src/flux/sampling.py @@ -0,0 +1,304 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from .model import Flux +from .modules.conditioner import HFEmbedder + + 
+def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], use_spatial_condition=False, use_share_weight_referencenet=False, share_position_embedding=False) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + if use_spatial_condition: + if share_position_embedding: + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + img_ids = torch.cat([img_ids, img_ids], dim=1) + else: + img_ids = torch.zeros(h // 2, w, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + elif use_share_weight_referencenet: + if share_position_embedding: + single_img_ids = torch.zeros(h // 2, w // 2, 3) + single_img_ids[..., 1] = single_img_ids[..., 1] + torch.arange(h // 2)[:, None] + single_img_ids[..., 2] = single_img_ids[..., 2] + torch.arange(w // 2)[None, :] + single_img_ids = repeat(single_img_ids, "h w c -> b (h w) c", b=bs) + img_ids = torch.cat([single_img_ids, single_img_ids], dim=1) + else: + # single_img_position_embedding + single_img_ids = torch.zeros(h // 2, w // 2, 3) + single_img_ids[..., 1] = single_img_ids[..., 1] + torch.arange(h // 2)[:, None] + single_img_ids[..., 2] = single_img_ids[..., 2] + torch.arange(w // 2)[None, :] + single_img_ids = repeat(single_img_ids, "h w c -> b (h w) c", b=bs) + # ref_and_noise_img_position_embedding + img_ids = torch.zeros(h // 2, w, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + else: + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) + if not use_share_weight_referencenet: + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + else: + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + "single_img_ids": single_img_ids.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1.0, + neg_ip_scale: Tensor | float = 1.0, + source_image: Tensor=None, + use_share_weight_referencenet=False, + single_img_ids=None, + neg_single_img_ids=None, + single_block_refnet=False, + double_block_refnet=False, +): + i = 0 + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + if source_image is not None: # spatial condition or refnet + img = torch.cat([source_image, img],dim=-2) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=image_proj, + ip_scale=ip_scale, + use_share_weight_referencenet=use_share_weight_referencenet, + single_img_ids=single_img_ids, + single_block_refnet=single_block_refnet, + double_block_refnet=double_block_refnet, + ) + if i >= timestep_to_start_cfg: + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + use_share_weight_referencenet=use_share_weight_referencenet, + single_img_ids=neg_single_img_ids, + single_block_refnet=single_block_refnet, + double_block_refnet=double_block_refnet, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + if use_share_weight_referencenet: + zero_buffer = torch.zeros_like(pred) + pred = torch.cat([zero_buffer, pred], dim=1) + img = img + (t_prev - t_curr) * pred + if (source_image is not None): # spatial condition or refnet + latent_length = img.shape[-2] // 2 + img = img[:,latent_length:,:] + i += 1 + return img + +def denoise_controlnet( + model: Flux, + controlnet:None, + # model input + img: Tensor, 
+ img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + controlnet_cond, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + controlnet_gs=0.7, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1, + neg_ip_scale: Tensor | float = 1, +): + # this is ignored for schnell + i = 0 + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples], + image_proj=image_proj, + ip_scale=ip_scale, + ) + if i >= timestep_to_start_cfg: + neg_block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples], + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + + img = img + (t_prev - t_curr) * pred + + i += 1 + return img + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/src/flux/util.py b/src/flux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..186422b1b3a244697966b58e26fa3c1152d493cc --- /dev/null +++ b/src/flux/util.py @@ -0,0 +1,456 @@ +import os +from dataclasses import dataclass + +import torch +import json +import cv2 +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from optimum.quanto import requantize + +from .model import Flux, FluxParams +from .controlnet import ControlNetFlux +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder +from .annotator.dwpose import DWposeDetector +from .annotator.mlsd import MLSDdetector +from .annotator.canny import CannyDetector +from .annotator.midas import MidasDetector +from .annotator.hed import HEDdetector +from .annotator.tile import TileDetector +from .annotator.zoe import ZoeDetector + +def tensor_to_pil_image(in_image): + tensor = in_image.squeeze(0) + tensor = (tensor + 1) / 2 + tensor = tensor * 255 + numpy_array = tensor.permute(1, 2, 0).byte().numpy() + pil_image = Image.fromarray(numpy_array) + return pil_image + +def save_image(in_image, output_path): + tensor = in_image.squeeze(0) + tensor = (tensor + 1) / 2 + tensor = tensor * 255 + numpy_array = tensor.permute(1, 2, 0).byte().numpy() + image = Image.fromarray(numpy_array) + image.save(output_path) + +def load_safetensors(path): + tensors = {} + with 
safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if '.safetensors' in local_path: + print(f"Loading .safetensors checkpoint from {local_path}") + checkpoint = load_safetensors(local_path) + else: + print(f"Loading checkpoint from {local_path}") + checkpoint = torch.load(local_path, map_location='cpu') + elif repo_id is not None and name is not None: + print(f"Loading checkpoint {name} from repo id {repo_id}") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method, mode params +def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + if resolution == 0: + return img, lambda x: x + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +class Annotator: + def __init__(self, name: str, device: str): + if name == "canny": + processor = CannyDetector() + elif name == "openpose": + processor = DWposeDetector(device) + elif name == "depth": + processor = MidasDetector() + elif name == "hed": + processor = HEDdetector() + elif name == "hough": + processor = MLSDdetector() + elif name == "tile": + processor = TileDetector() + elif name == "zoe": + processor = ZoeDetector() + self.name = name + self.processor = processor + + def __call__(self, image: Image, width: int, height: int): + image = np.array(image) + detect_resolution = max(width, height) + image, remove_pad = resize_image_with_pad(image, detect_resolution) + + image = np.array(image) + if self.name == "canny": + result = self.processor(image, low_threshold=100, high_threshold=200) + elif self.name == "hough": + result = self.processor(image, thr_v=0.05, thr_d=5) + elif self.name == "depth": + result = self.processor(image) + result, _ = result 
+ else: + result = self.processor(image) + + result = HWC3(remove_pad(result)) + result = cv2.resize(result, (width, height)) + return result + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="XLabs-AI/flux-dev-fp8", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux-dev-fp8.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device='cpu') + return sd + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, 
configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') + + + model = Flux(configs[name].params).to(torch.bfloat16) + + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device='cpu') + with open(json_path, "r") as f: + quantization_map = json.load(f) + print("Start a quantization process...") + requantize(model, sd, quantization_map, device=device) + print("Model is quantized!") + return model + +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = ControlNetFlux(configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + t5_path = os.getenv("T5") + if t5_path is None: + return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + else: + return HFEmbedder(t5_path, max_length=max_length, torch_dtype=torch.bfloat16).to(device) + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + clip_path = os.getenv("CLIP_VIT") + if clip_path is None: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + else: + return HFEmbedder(clip_path, max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with 
torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was choosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] diff --git a/src/flux/xflux_pipeline.py b/src/flux/xflux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..cda3a073e84966dfac7742bab40720a3d5145a23 --- /dev/null +++ b/src/flux/xflux_pipeline.py @@ -0,0 +1,407 @@ +from PIL import Image, ExifTags +import numpy as np +import torch +from torch import Tensor + +from einops import rearrange +import uuid +import os + +from src.flux.modules.layers import ( + SingleStreamBlockProcessor, + DoubleStreamBlockProcessor, + SingleStreamBlockLoraProcessor, + DoubleStreamBlockLoraProcessor, + IPDoubleStreamBlockProcessor, + ImageProjModel, +) +from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack +from src.flux.util import ( + load_ae, + load_clip, + load_flow_model, + load_t5, + load_controlnet, + load_flow_model_quintized, + Annotator, + get_lora_rank, + load_checkpoint +) + +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor + +class XFluxPipeline: + def __init__(self, model_type, device, offload: bool = False): + self.device = torch.device(device) + self.offload = offload + self.model_type = model_type + + self.clip = load_clip(self.device) + self.t5 = load_t5(self.device, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.device) + if "fp8" in model_type: + self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) + else: + self.model = load_flow_model(model_type, device="cpu" if offload else self.device) + + self.image_encoder_path = "openai/clip-vit-large-patch14" + self.hf_lora_collection = "XLabs-AI/flux-lora-collection" + self.lora_types_to_names = { + "realism": "lora.safetensors", + } + self.controlnet_loaded = False + self.ip_loaded = False + self.spatial_condition = False + self.share_position_embedding = 
False + self.use_share_weight_referencenet = False + self.single_block_refnet = False + self.double_block_refnet = False + + def set_ip(self, local_path: str = None, repo_id = None, name: str = None): + self.model.to(self.device) + + # unpack checkpoint + checkpoint = load_checkpoint(local_path, repo_id, name) + prefix = "double_blocks." + blocks = {} + proj = {} + + for key, value in checkpoint.items(): + if key.startswith(prefix): + blocks[key[len(prefix):].replace('.processor.', '.')] = value + if key.startswith("ip_adapter_proj_model"): + proj[key[len("ip_adapter_proj_model."):]] = value + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + + # setup image embedding projection model + self.improj = ImageProjModel(4096, 768, 4) + self.improj.load_state_dict(proj) + self.improj = self.improj.to(self.device, dtype=torch.bfloat16) + + ip_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + ip_state_dict = {} + for k in checkpoint.keys(): + if name in k: + ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] + if ip_state_dict: + ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) + ip_attn_procs[name].load_state_dict(ip_state_dict) + ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) + else: + ip_attn_procs[name] = self.model.attn_processors[name] + + self.model.set_attn_processor(ip_attn_procs) + self.ip_loaded = True + + def set_lora(self, local_path: str = None, repo_id: str = None, + name: str = None, lora_weight: int = 0.7): + checkpoint = load_checkpoint(local_path, repo_id, name) + self.update_model_with_lora(checkpoint, lora_weight) + + def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): + checkpoint = load_checkpoint( + None, self.hf_lora_collection, self.lora_types_to_names[lora_type] + ) + self.update_model_with_lora(checkpoint, lora_weight) + + def update_model_with_lora(self, checkpoint, lora_weight): + rank = get_lora_rank(checkpoint) + lora_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight + + if len(lora_state_dict): + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) + else: + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(self.device) + else: + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockProcessor() + else: + lora_attn_procs[name] = DoubleStreamBlockProcessor() + + self.model.set_attn_processor(lora_attn_procs) + + def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): + self.model.to(self.device) + self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) + + checkpoint = load_checkpoint(local_path, repo_id, name) + self.controlnet.load_state_dict(checkpoint, strict=False) + self.annotator = Annotator(control_type, self.device) + self.controlnet_loaded = True + self.control_type = control_type + + def get_image_proj( + self, + image_prompt: Tensor, + ): + # encode image-prompt embeds + image_prompt = self.clip_image_processor( + images=image_prompt, + return_tensors="pt" + ).pixel_values + 
image_prompt = image_prompt.to(self.image_encoder.device) + image_prompt_embeds = self.image_encoder( + image_prompt + ).image_embeds.to( + device=self.device, dtype=torch.bfloat16, + ) + # encode image + image_proj = self.improj(image_prompt_embeds) + return image_proj + + def __call__(self, + prompt: str, + image_prompt: Image = None, + source_image: Tensor = None, + controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3.5, # 3 + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_prompt: Image = None, + timestep_to_start_cfg: int = 1, # 0 + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + image_proj = None + neg_image_proj = None + if not (image_prompt is None and neg_image_prompt is None) : + assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + if image_prompt is None: + image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + if neg_image_prompt is None: + neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + image_proj = self.get_image_proj(image_prompt) + neg_image_proj = self.get_image_proj(neg_image_prompt) + + if self.controlnet_loaded: + controlnet_image = self.annotator(controlnet_image, width, height) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute( + 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + spatial_condition=self.spatial_condition, + source_image=source_image, + share_position_embedding=self.share_position_embedding + ) + + @torch.inference_mode() + def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, + lora_weight, local_path, lora_local_path, ip_local_path): + if controlnet_image is not None: + controlnet_image = Image.fromarray(controlnet_image) + if ((self.controlnet_loaded and control_type != self.control_type) + or not self.controlnet_loaded): + if local_path is not None: + self.set_controlnet(control_type, local_path=local_path) + else: + self.set_controlnet(control_type, local_path=None, + repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3", + name=f"flux-{control_type}-controlnet-v3.safetensors") + if lora_local_path is not None: + self.set_lora(local_path=lora_local_path, lora_weight=lora_weight) + if image_prompt is not None: + image_prompt = Image.fromarray(image_prompt) + if neg_image_prompt is not None: + neg_image_prompt = Image.fromarray(neg_image_prompt) + if not self.ip_loaded: + if ip_local_path is not None: + self.set_ip(local_path=ip_local_path) + else: + self.set_ip(repo_id="xlabs-ai/flux-ip-adapter", + name="flux-ip-adapter.safetensors") + seed = int(seed) + if seed == -1: + seed = torch.Generator(device="cpu").seed() + + img = self(prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, 
timestep_to_start_cfg) + + filename = f"output/gradio/{uuid.uuid4()}.jpg" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Make] = "XLabs AI" + exif_data[ExifTags.Base.Model] = self.model_type + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) + return img, filename + + def forward( + self, + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image = None, + timestep_to_start_cfg = 0, + true_gs = 3.5, + control_weight = 0.9, + neg_prompt="", + image_proj=None, + neg_image_proj=None, + ip_scale=1.0, + neg_ip_scale=1.0, + spatial_condition=False, + source_image=None, + share_position_embedding=False + ): + x = get_noise( + 1, height, width, device=self.device, + dtype=torch.bfloat16, seed=seed + ) + timesteps = get_schedule( + num_steps, + (width // 8) * (height // 8) // (16 * 16), + shift=True, + ) + torch.manual_seed(seed) + with torch.no_grad(): + if self.offload: + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + # print("x noise shape:", x.shape) + inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt, use_spatial_condition=spatial_condition, share_position_embedding=share_position_embedding, use_share_weight_referencenet=self.use_share_weight_referencenet) + # print("input img noise shape:", inp_cond['img'].shape) + neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt, use_spatial_condition=spatial_condition, share_position_embedding=share_position_embedding, use_share_weight_referencenet=self.use_share_weight_referencenet) + if spatial_condition or self.use_share_weight_referencenet: + # TODO here: + source_image = self.ae.encode(source_image.to(self.device).to(torch.float32)) + # print("ae source image shape:", source_image.shape) + source_image = rearrange(source_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(inp_cond['img'].dtype) + + # print("rearrange ae source image shape:", source_image.shape) + + if self.offload: + self.offload_model_to_cpu(self.t5, self.clip) + self.model = self.model.to(self.device) + if self.controlnet_loaded: + x = denoise_controlnet( + self.model, + img=inp_cond['img'], + img_ids=inp_cond['img_ids'], + txt=inp_cond['txt'], + txt_ids=inp_cond['txt_ids'], + vec=inp_cond['vec'], + controlnet=self.controlnet, + timesteps=timesteps, + guidance=guidance, + controlnet_cond=controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + controlnet_gs=control_weight, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + else: + x = denoise( + self.model, + img=inp_cond['img'], + img_ids=inp_cond['img_ids'], + txt=inp_cond['txt'], + txt_ids=inp_cond['txt_ids'], + vec=inp_cond['vec'], + timesteps=timesteps, + guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + source_image=source_image, # spatial_condition source image + use_share_weight_referencenet=self.use_share_weight_referencenet, + single_img_ids=inp_cond['single_img_ids'] if self.use_share_weight_referencenet else None, + neg_single_img_ids=neg_inp_cond['single_img_ids'] if 
self.use_share_weight_referencenet else None, + single_block_refnet=self.single_block_refnet, + double_block_refnet=self.double_block_refnet, + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: return + for model in models: + model.cpu() + torch.cuda.empty_cache() + + +class XFluxSampler(XFluxPipeline): + def __init__(self, clip, t5, ae, model, device,controlnet_loaded=False,ip_loaded=False, spatial_condition=False, offload=False, clip_image_processor=None, image_encoder=None, improj=None, share_position_embedding=False, use_share_weight_referencenet=False, single_block_refnet=False, double_block_refnet=False): + self.clip = clip + self.t5 = t5 + self.ae = ae + self.model = model + self.model.eval() + self.device = device + self.controlnet_loaded = controlnet_loaded + self.ip_loaded = ip_loaded + self.offload = offload + self.clip_image_processor = clip_image_processor + self.image_encoder = image_encoder + self.improj = improj + self.spatial_condition = spatial_condition + self.share_position_embedding = share_position_embedding + self.use_share_weight_referencenet = use_share_weight_referencenet + self.single_block_refnet = single_block_refnet + self.double_block_refnet = double_block_refnet \ No newline at end of file
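
For orientation, the sketch below shows one way the pieces introduced in this diff fit together: `XFluxPipeline` loads the autoencoder, the CLIP and T5 embedders, and the Flux transformer, and its `__call__` runs `prepare`, `get_schedule`, and `denoise` before decoding the latent back to a PIL image. This is a minimal usage sketch from the editor, not part of the committed files; it assumes the `flux-dev` config, a CUDA device, and that the FLUX.1-dev and xlabs-ai/xflux_text_encoders weights can be downloaded from Hugging Face. The prompt string and output filename are placeholders.

```python
# Minimal usage sketch (editor's example, not in the diff).
# Assumes: a CUDA GPU and Hugging Face access to black-forest-labs/FLUX.1-dev
# plus the xlabs-ai/xflux_text_encoders weights.
from src.flux.xflux_pipeline import XFluxPipeline

# Loads the AE, CLIP, T5 and the Flux transformer; offload=True would keep
# submodules on the CPU except while they are in use.
pipe = XFluxPipeline(model_type="flux-dev", device="cuda", offload=False)

# Plain text-to-image path: no IP-adapter, ControlNet, or source image, so only
# prepare/get_schedule/denoise with the default block processors are exercised.
image = pipe(
    prompt="a misty forest at sunrise",  # placeholder prompt
    width=512,
    height=512,
    guidance=4.0,
    num_steps=25,
    seed=123456789,
    true_gs=3.5,
)
image.save("example_output.jpg")  # placeholder output path
```

ControlNet or IP-adapter conditioning would be enabled beforehand via `set_controlnet(...)` or `set_ip(...)`, which swap the corresponding attention processors into the model.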